1010from mypy .checkmember import analyze_member_access
1111from mypy .expandtype import expand_type_by_instance
1212from mypy .join import join_types
13- from mypy .literals import literal_hash
13+ from mypy .literals import Key , literal_hash
1414from mypy .maptype import map_instance_to_supertype
1515from mypy .meet import narrow_declared_type
1616from mypy .messages import MessageBuilder
17- from mypy .nodes import ARG_POS , Context , Expression , NameExpr , TempNode , TypeAlias , Var
17+ from mypy .nodes import (
18+ ARG_POS ,
19+ Context ,
20+ Expression ,
21+ IndexExpr ,
22+ IntExpr ,
23+ ListExpr ,
24+ MemberExpr ,
25+ NameExpr ,
26+ TempNode ,
27+ TupleExpr ,
28+ TypeAlias ,
29+ UnaryExpr ,
30+ Var ,
31+ )
1832from mypy .options import Options
1933from mypy .patterns import (
2034 AsPattern ,
@@ -98,10 +112,8 @@ class PatternChecker(PatternVisitor[PatternType]):
98112 msg : MessageBuilder
99113 # Currently unused
100114 plugin : Plugin
101- # The expression being matched against the pattern
102- subject : Expression
103-
104- subject_type : Type
115+ # The expressions being matched against the (sub)pattern
116+ subject_context : list [list [Expression ]]
105117 # Type of the subject to check the (sub)pattern against
106118 type_context : list [Type ]
107119 # Types that match against self instead of their __match_args__ if used as a class pattern
@@ -120,24 +132,28 @@ def __init__(
120132 self .msg = msg
121133 self .plugin = plugin
122134
135+ self .subject_context = []
123136 self .type_context = []
124137 self .self_match_types = self .generate_types_from_names (self_match_type_names )
125138 self .non_sequence_match_types = self .generate_types_from_names (
126139 non_sequence_match_type_names
127140 )
128141 self .options = options
129142
130- def accept (self , o : Pattern , type_context : Type ) -> PatternType :
143+ def accept (self , o : Pattern , type_context : Type , subject : list [Expression ]) -> PatternType :
144+ self .subject_context .append (subject )
131145 self .type_context .append (type_context )
132146 result = o .accept (self )
147+ self .subject_context .pop ()
133148 self .type_context .pop ()
134149
135150 return result
136151
137152 def visit_as_pattern (self , o : AsPattern ) -> PatternType :
153+ current_subject = self .subject_context [- 1 ]
138154 current_type = self .type_context [- 1 ]
139155 if o .pattern is not None :
140- pattern_type = self .accept (o .pattern , current_type )
156+ pattern_type = self .accept (o .pattern , current_type , current_subject )
141157 typ , rest_type , type_map = pattern_type
142158 else :
143159 typ , rest_type , type_map = current_type , UninhabitedType (), {}
@@ -152,14 +168,15 @@ def visit_as_pattern(self, o: AsPattern) -> PatternType:
152168 return PatternType (typ , rest_type , type_map )
153169
154170 def visit_or_pattern (self , o : OrPattern ) -> PatternType :
171+ current_subject = self .subject_context [- 1 ]
155172 current_type = self .type_context [- 1 ]
156173
157174 #
158175 # Check all the subpatterns
159176 #
160- pattern_types = []
177+ pattern_types : list [ PatternType ] = []
161178 for pattern in o .patterns :
162- pattern_type = self .accept (pattern , current_type )
179+ pattern_type = self .accept (pattern , current_type , current_subject )
163180 pattern_types .append (pattern_type )
164181 if not is_uninhabited (pattern_type .type ):
165182 current_type = pattern_type .rest_type
@@ -175,28 +192,42 @@ def visit_or_pattern(self, o: OrPattern) -> PatternType:
175192 #
176193 # Check the capture types
177194 #
178- capture_types : dict [Var , list [tuple [Expression , Type ]]] = defaultdict (list )
195+ capture_types : dict [Var , dict [Key | None , list [tuple [Expression , Type ]]]] = defaultdict (
196+ lambda : defaultdict (list )
197+ )
198+ capture_expr_keys : set [Key | None ] = set ()
179199 # Collect captures from the first subpattern
180200 for expr , typ in pattern_types [0 ].captures .items ():
181- node = get_var (expr )
182- capture_types [node ].append ((expr , typ ))
201+ if (node := get_var (expr )) is None :
202+ continue
203+ key = literal_hash (expr )
204+ capture_types [node ][key ].append ((expr , typ ))
205+ if isinstance (expr , NameExpr ):
206+ capture_expr_keys .add (key )
183207
184208 # Check if other subpatterns capture the same names
185209 for i , pattern_type in enumerate (pattern_types [1 :]):
186- vars = {get_var (expr ) for expr , _ in pattern_type .captures .items ()}
187- if capture_types .keys () != vars :
210+ vars = {
211+ literal_hash (expr ) for expr in pattern_type .captures if isinstance (expr , NameExpr )
212+ }
213+ if capture_expr_keys != vars :
214+ # Only fail for directly captured names (with NameExpr)
188215 self .msg .fail (message_registry .OR_PATTERN_ALTERNATIVE_NAMES , o .patterns [i ])
189216 for expr , typ in pattern_type .captures .items ():
190- node = get_var (expr )
191- capture_types [node ].append ((expr , typ ))
217+ if (node := get_var (expr )) is None :
218+ continue
219+ key = literal_hash (expr )
220+ capture_types [node ][key ].append ((expr , typ ))
192221
193222 captures : dict [Expression , Type ] = {}
194- for capture_list in capture_types .values ():
195- typ = UninhabitedType ()
196- for _ , other in capture_list :
197- typ = make_simplified_union ([typ , other ])
223+ for expressions in capture_types .values ():
224+ for key , capture_list in expressions .items ():
225+ if other_types := [entry [1 ] for entry in capture_list ]:
226+ typ = make_simplified_union (other_types )
227+ else :
228+ typ = UninhabitedType ()
198229
199- captures [capture_list [0 ][0 ]] = typ
230+ captures [capture_list [0 ][0 ]] = typ
200231
201232 union_type = make_simplified_union (types )
202233 return PatternType (union_type , current_type , captures )
@@ -289,12 +320,37 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
289320 contracted_inner_types = self .contract_starred_pattern_types (
290321 inner_types , star_position , required_patterns
291322 )
292- for p , t in zip (o .patterns , contracted_inner_types ):
293- pattern_type = self .accept (p , t )
323+ current_subjects : list [list [Expression ]] = [[] for _ in range (len (contracted_inner_types ))]
324+ end_pos = len (contracted_inner_types ) if star_position is None else star_position
325+ for subject in self .subject_context [- 1 ]:
326+ if isinstance (subject , (ListExpr , TupleExpr )):
327+ # For list and tuple expressions, lookup expression in items
328+ for i in range (end_pos ):
329+ if i < len (subject .items ):
330+ current_subjects [i ].append (subject .items [i ])
331+ if star_position is not None :
332+ for i in range (star_position + 1 , len (contracted_inner_types )):
333+ offset = len (contracted_inner_types ) - i
334+ if offset <= len (subject .items ):
335+ current_subjects [i ].append (subject .items [- offset ])
336+ else :
337+ # Support x[0], x[1], ... lookup until wildcard
338+ for i in range (end_pos ):
339+ current_subjects [i ].append (IndexExpr (subject , IntExpr (i )))
340+ # For everything after wildcard use x[-2], x[-1]
341+ for i in range ((star_position or - 1 ) + 1 , len (contracted_inner_types )):
342+ offset = len (contracted_inner_types ) - i
343+ current_subjects [i ].append (IndexExpr (subject , UnaryExpr ("-" , IntExpr (offset ))))
344+ for p , t , s in zip (o .patterns , contracted_inner_types , current_subjects ):
345+ pattern_type = self .accept (p , t , s )
294346 typ , rest , type_map = pattern_type
295347 contracted_new_inner_types .append (typ )
296348 contracted_rest_inner_types .append (rest )
297349 self .update_type_map (captures , type_map )
350+ if s :
351+ self .update_type_map (
352+ captures , {subject : typ for subject in s }, fail_multiple_assignments = False
353+ )
298354
299355 new_inner_types = self .expand_starred_pattern_types (
300356 contracted_new_inner_types , star_position , len (inner_types ), unpack_index is not None
@@ -478,11 +534,18 @@ def visit_mapping_pattern(self, o: MappingPattern) -> PatternType:
478534 if inner_type is None :
479535 can_match = False
480536 inner_type = self .chk .named_type ("builtins.object" )
481- pattern_type = self .accept (value , inner_type )
537+ current_subjects : list [Expression ] = [
538+ IndexExpr (s , key ) for s in self .subject_context [- 1 ]
539+ ]
540+ pattern_type = self .accept (value , inner_type , current_subjects )
482541 if is_uninhabited (pattern_type .type ):
483542 can_match = False
484543 else :
485544 self .update_type_map (captures , pattern_type .captures )
545+ if current_subjects :
546+ self .update_type_map (
547+ captures , {subject : pattern_type .type for subject in current_subjects }
548+ )
486549
487550 if o .rest is not None :
488551 mapping = self .chk .named_type ("typing.Mapping" )
@@ -590,7 +653,7 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
590653 if self .should_self_match (typ ):
591654 if len (o .positionals ) > 1 :
592655 self .msg .fail (message_registry .CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS , o )
593- pattern_type = self .accept (o .positionals [0 ], narrowed_type )
656+ pattern_type = self .accept (o .positionals [0 ], narrowed_type , [] )
594657 if not is_uninhabited (pattern_type .type ):
595658 return PatternType (
596659 pattern_type .type ,
@@ -690,11 +753,20 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
690753 elif keyword is not None :
691754 new_type = self .chk .add_any_attribute_to_type (new_type , keyword )
692755
693- inner_type , inner_rest_type , inner_captures = self .accept (pattern , key_type )
756+ current_subjects : list [Expression ] = []
757+ if keyword is not None :
758+ current_subjects = [MemberExpr (s , keyword ) for s in self .subject_context [- 1 ]]
759+ inner_type , inner_rest_type , inner_captures = self .accept (
760+ pattern , key_type , current_subjects
761+ )
694762 if is_uninhabited (inner_type ):
695763 can_match = False
696764 else :
697765 self .update_type_map (captures , inner_captures )
766+ if current_subjects :
767+ self .update_type_map (
768+ captures , {subject : inner_type for subject in current_subjects }
769+ )
698770 if not is_uninhabited (inner_rest_type ):
699771 rest_type = current_type
700772
@@ -743,17 +815,22 @@ def generate_types_from_names(self, type_names: list[str]) -> list[Type]:
743815 return types
744816
745817 def update_type_map (
746- self , original_type_map : dict [Expression , Type ], extra_type_map : dict [Expression , Type ]
818+ self ,
819+ original_type_map : dict [Expression , Type ],
820+ extra_type_map : dict [Expression , Type ],
821+ fail_multiple_assignments : bool = True ,
747822 ) -> None :
748823 # Calculating this would not be needed if TypeMap directly used literal hashes instead of
749824 # expressions, as suggested in the TODO above it's definition
750825 already_captured = {literal_hash (expr ) for expr in original_type_map }
751826 for expr , typ in extra_type_map .items ():
752827 if literal_hash (expr ) in already_captured :
753- node = get_var (expr )
754- self .msg .fail (
755- message_registry .MULTIPLE_ASSIGNMENTS_IN_PATTERN .format (node .name ), expr
756- )
828+ if (node := get_var (expr )) is None :
829+ continue
830+ if fail_multiple_assignments :
831+ self .msg .fail (
832+ message_registry .MULTIPLE_ASSIGNMENTS_IN_PATTERN .format (node .name ), expr
833+ )
757834 else :
758835 original_type_map [expr ] = typ
759836
@@ -805,12 +882,17 @@ def get_match_arg_names(typ: TupleType) -> list[str | None]:
805882 return args
806883
807884
808- def get_var (expr : Expression ) -> Var :
885+ def get_var (expr : Expression ) -> Var | None :
809886 """
810887 Warning: this in only true for expressions captured by a match statement.
811888 Don't call it from anywhere else
812889 """
813- assert isinstance (expr , NameExpr ), expr
890+ if isinstance (expr , MemberExpr ):
891+ return get_var (expr .expr )
892+ if isinstance (expr , IndexExpr ):
893+ return get_var (expr .base )
894+ if not isinstance (expr , NameExpr ):
895+ return None
814896 node = expr .node
815897 assert isinstance (node , Var ), node
816898 return node
0 commit comments