From 2499ce217a3639c7f9485ce62525240844f81bed Mon Sep 17 00:00:00 2001 From: Pierre Sassoulas Date: Sun, 14 Sep 2025 13:52:06 +0200 Subject: [PATCH 1/4] [refactor to use match] AssertionRewriter.run() Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com> --- src/_pytest/assertion/rewrite.py | 29 +++++++++++------------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index bff33ccf155..df79ae98d3b 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -703,24 +703,17 @@ def run(self, mod: ast.Module) -> None: return pos = 0 for item in mod.body: - if ( - expect_docstring - and isinstance(item, ast.Expr) - and isinstance(item.value, ast.Constant) - and isinstance(item.value.value, str) - ): - doc = item.value.value - if self.is_rewrite_disabled(doc): - return - expect_docstring = False - elif ( - isinstance(item, ast.ImportFrom) - and item.level == 0 - and item.module == "__future__" - ): - pass - else: - break + match item: + case ast.Expr(value=ast.Constant(value=str() as doc)) if ( + expect_docstring + ): + if self.is_rewrite_disabled(doc): + return + expect_docstring = False + case ast.ImportFrom(level=0, module="__future__"): + pass + case _: + break pos += 1 # Special case: for a decorated function, set the lineno to that of the # first decorator, not the `def`. Issue #4984. From ef7a432896ed8ce0dd0316597b18f2260037bf6b Mon Sep 17 00:00:00 2001 From: Pierre Sassoulas Date: Sun, 14 Sep 2025 14:00:27 +0200 Subject: [PATCH 2/4] [refactor to use match] AssertionRewriter.visit_Compare() --- src/_pytest/assertion/rewrite.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index df79ae98d3b..4ef68f28082 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -1112,12 +1112,13 @@ def visit_Attribute(self, attr: ast.Attribute) -> tuple[ast.Name, str]: def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]: self.push_format_context() # We first check if we have overwritten a variable in the previous assert - if isinstance( - comp.left, ast.Name - ) and comp.left.id in self.variables_overwrite.get(self.scope, {}): - comp.left = self.variables_overwrite[self.scope][comp.left.id] # type:ignore[assignment] - if isinstance(comp.left, ast.NamedExpr): - self.variables_overwrite[self.scope][comp.left.target.id] = comp.left # type:ignore[assignment] + match comp.left: + case ast.Name(id=name_id) if name_id in self.variables_overwrite.get( + self.scope, {} + ): + comp.left = self.variables_overwrite[self.scope][name_id] # type: ignore[assignment] + case ast.NamedExpr(target=ast.Name(id=target_id)): + self.variables_overwrite[self.scope][target_id] = comp.left # type: ignore[assignment] left_res, left_expl = self.visit(comp.left) if isinstance(comp.left, ast.Compare | ast.BoolOp): left_expl = f"({left_expl})" @@ -1129,13 +1130,14 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]: syms: list[ast.expr] = [] results = [left_res] for i, op, next_operand in it: - if ( - isinstance(next_operand, ast.NamedExpr) - and isinstance(left_res, ast.Name) - and next_operand.target.id == left_res.id - ): - next_operand.target.id = self.variable() - self.variables_overwrite[self.scope][left_res.id] = next_operand # type:ignore[assignment] + match (next_operand, left_res): + case ( + ast.NamedExpr(target=ast.Name(id=target_id)), + ast.Name(id=name_id), + ) if target_id == name_id: + next_operand.target.id = self.variable() + self.variables_overwrite[self.scope][name_id] = next_operand # type: ignore[assignment] + next_res, next_expl = self.visit(next_operand) if isinstance(next_operand, ast.Compare | ast.BoolOp): next_expl = f"({next_expl})" From b4e8769da499bff9020a0726cf7b83a3b77fac84 Mon Sep 17 00:00:00 2001 From: Pierre Sassoulas Date: Sun, 14 Sep 2025 14:26:53 +0200 Subject: [PATCH 3/4] [refactor to use match] AssertionRewriter.visit_BoolOp() --- src/_pytest/assertion/rewrite.py | 25 +++++++++++-------------- testing/test_assertrewrite.py | 20 +++++++++++++++++++- 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 4ef68f28082..d65d85be44b 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -1010,20 +1010,17 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]: # cond is set in a prior loop iteration below self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa: F821 self.expl_stmts = fail_inner - # Check if the left operand is a ast.NamedExpr and the value has already been visited - if ( - isinstance(v, ast.Compare) - and isinstance(v.left, ast.NamedExpr) - and v.left.target.id - in [ - ast_expr.id - for ast_expr in boolop.values[:i] - if hasattr(ast_expr, "id") - ] - ): - pytest_temp = self.variable() - self.variables_overwrite[self.scope][v.left.target.id] = v.left # type:ignore[assignment] - v.left.target.id = pytest_temp + match v: + # Check if the left operand is an ast.NamedExpr and the value has already been visited + case ast.Compare( + left=ast.NamedExpr(target=ast.Name(id=target_id)) + ) if target_id in [ + e.id for e in boolop.values[:i] if hasattr(e, "id") + ]: + pytest_temp = self.variable() + self.variables_overwrite[self.scope][target_id] = v.left # type:ignore[assignment] + # mypy's false positive, we're checking that the 'target' attribute exists. + v.left.target.id = pytest_temp # type:ignore[attr-defined] self.push_format_context() res, expl = self.visit(v) body.append(ast.Assign([ast.Name(res_var, ast.Store())], res)) diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 18bc32dc86f..92664354470 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1552,7 +1552,9 @@ def test_simple_failure(): result.stdout.fnmatch_lines(["*E*assert (1 + 1) == 3"]) -class TestIssue10743: +class TestAssertionRewriteWalrusOperator: + """See #10743""" + def test_assertion_walrus_operator(self, pytester: Pytester) -> None: pytester.makepyfile( """ @@ -1719,6 +1721,22 @@ def test_walrus_operator_not_override_value(): result = pytester.runpytest() assert result.ret == 0 + def test_assertion_namedexpr_compare_left_overwrite( + self, pytester: Pytester + ) -> None: + pytester.makepyfile( + """ + def test_namedexpr_compare_left_overwrite(): + a = "Hello" + b = "World" + c = "Test" + assert (a := b) == c and (a := "Test") == "Test" + """ + ) + result = pytester.runpytest() + assert result.ret == 1 + result.stdout.fnmatch_lines(["*assert ('World' == 'Test'*"]) + class TestIssue11028: def test_assertion_walrus_operator_in_operand(self, pytester: Pytester) -> None: From 7308a725ef20bfd979c66555c3eb586d5d75fcb8 Mon Sep 17 00:00:00 2001 From: Pierre Sassoulas Date: Sun, 14 Sep 2025 14:29:30 +0200 Subject: [PATCH 4/4] [refactor to use match] AssertionRewriter.visit_Call() --- src/_pytest/assertion/rewrite.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index d65d85be44b..566549d66f2 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -1070,10 +1070,11 @@ def visit_Call(self, call: ast.Call) -> tuple[ast.Name, str]: arg_expls.append(expl) new_args.append(res) for keyword in call.keywords: - if isinstance( - keyword.value, ast.Name - ) and keyword.value.id in self.variables_overwrite.get(self.scope, {}): - keyword.value = self.variables_overwrite[self.scope][keyword.value.id] # type:ignore[assignment] + match keyword.value: + case ast.Name(id=id) if id in self.variables_overwrite.get( + self.scope, {} + ): + keyword.value = self.variables_overwrite[self.scope][id] # type:ignore[assignment] res, expl = self.visit(keyword.value) new_kwargs.append(ast.keyword(keyword.arg, res)) if keyword.arg: