Skip to content

Commit 3cb1208

Browse files
[refactor to use match] AssertionRewriter.visit_Compare()
1 parent bca78ad commit 3cb1208

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

src/_pytest/assertion/rewrite.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,12 +1110,13 @@ def visit_Attribute(self, attr: ast.Attribute) -> tuple[ast.Name, str]:
11101110
def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]:
11111111
self.push_format_context()
11121112
# We first check if we have overwritten a variable in the previous assert
1113-
if isinstance(
1114-
comp.left, ast.Name
1115-
) and comp.left.id in self.variables_overwrite.get(self.scope, {}):
1116-
comp.left = self.variables_overwrite[self.scope][comp.left.id] # type:ignore[assignment]
1117-
if isinstance(comp.left, ast.NamedExpr):
1118-
self.variables_overwrite[self.scope][comp.left.target.id] = comp.left # type:ignore[assignment]
1113+
match comp.left:
1114+
case ast.Name(id=name_id) if name_id in self.variables_overwrite.get(
1115+
self.scope, {}
1116+
):
1117+
comp.left = self.variables_overwrite[self.scope][name_id] # type: ignore[assignment]
1118+
case ast.NamedExpr(target=ast.Name(id=target_id)):
1119+
self.variables_overwrite[self.scope][target_id] = comp.left # type: ignore[assignment]
11191120
left_res, left_expl = self.visit(comp.left)
11201121
if isinstance(comp.left, ast.Compare | ast.BoolOp):
11211122
left_expl = f"({left_expl})"
@@ -1127,13 +1128,14 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]:
11271128
syms: list[ast.expr] = []
11281129
results = [left_res]
11291130
for i, op, next_operand in it:
1130-
if (
1131-
isinstance(next_operand, ast.NamedExpr)
1132-
and isinstance(left_res, ast.Name)
1133-
and next_operand.target.id == left_res.id
1134-
):
1135-
next_operand.target.id = self.variable()
1136-
self.variables_overwrite[self.scope][left_res.id] = next_operand # type:ignore[assignment]
1131+
match (next_operand, left_res):
1132+
case (
1133+
ast.NamedExpr(target=ast.Name(id=target_id)),
1134+
ast.Name(id=name_id),
1135+
) if target_id == name_id:
1136+
next_operand.target.id = self.variable()
1137+
self.variables_overwrite[self.scope][name_id] = next_operand # type: ignore[assignment]
1138+
11371139
next_res, next_expl = self.visit(next_operand)
11381140
if isinstance(next_operand, ast.Compare | ast.BoolOp):
11391141
next_expl = f"({next_expl})"

0 commit comments

Comments
 (0)