diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index bff33ccf155..566549d66f2 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -703,24 +703,17 @@ def run(self, mod: ast.Module) -> None: return pos = 0 for item in mod.body: - if ( - expect_docstring - and isinstance(item, ast.Expr) - and isinstance(item.value, ast.Constant) - and isinstance(item.value.value, str) - ): - doc = item.value.value - if self.is_rewrite_disabled(doc): - return - expect_docstring = False - elif ( - isinstance(item, ast.ImportFrom) - and item.level == 0 - and item.module == "__future__" - ): - pass - else: - break + match item: + case ast.Expr(value=ast.Constant(value=str() as doc)) if ( + expect_docstring + ): + if self.is_rewrite_disabled(doc): + return + expect_docstring = False + case ast.ImportFrom(level=0, module="__future__"): + pass + case _: + break pos += 1 # Special case: for a decorated function, set the lineno to that of the # first decorator, not the `def`. Issue #4984. @@ -1017,20 +1010,17 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]: # cond is set in a prior loop iteration below self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa: F821 self.expl_stmts = fail_inner - # Check if the left operand is a ast.NamedExpr and the value has already been visited - if ( - isinstance(v, ast.Compare) - and isinstance(v.left, ast.NamedExpr) - and v.left.target.id - in [ - ast_expr.id - for ast_expr in boolop.values[:i] - if hasattr(ast_expr, "id") - ] - ): - pytest_temp = self.variable() - self.variables_overwrite[self.scope][v.left.target.id] = v.left # type:ignore[assignment] - v.left.target.id = pytest_temp + match v: + # Check if the left operand is an ast.NamedExpr and the value has already been visited + case ast.Compare( + left=ast.NamedExpr(target=ast.Name(id=target_id)) + ) if target_id in [ + e.id for e in boolop.values[:i] if hasattr(e, "id") + ]: + pytest_temp = self.variable() + self.variables_overwrite[self.scope][target_id] = v.left # type:ignore[assignment] + # mypy's false positive, we're checking that the 'target' attribute exists. + v.left.target.id = pytest_temp # type:ignore[attr-defined] self.push_format_context() res, expl = self.visit(v) body.append(ast.Assign([ast.Name(res_var, ast.Store())], res)) @@ -1080,10 +1070,11 @@ def visit_Call(self, call: ast.Call) -> tuple[ast.Name, str]: arg_expls.append(expl) new_args.append(res) for keyword in call.keywords: - if isinstance( - keyword.value, ast.Name - ) and keyword.value.id in self.variables_overwrite.get(self.scope, {}): - keyword.value = self.variables_overwrite[self.scope][keyword.value.id] # type:ignore[assignment] + match keyword.value: + case ast.Name(id=id) if id in self.variables_overwrite.get( + self.scope, {} + ): + keyword.value = self.variables_overwrite[self.scope][id] # type:ignore[assignment] res, expl = self.visit(keyword.value) new_kwargs.append(ast.keyword(keyword.arg, res)) if keyword.arg: @@ -1119,12 +1110,13 @@ def visit_Attribute(self, attr: ast.Attribute) -> tuple[ast.Name, str]: def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]: self.push_format_context() # We first check if we have overwritten a variable in the previous assert - if isinstance( - comp.left, ast.Name - ) and comp.left.id in self.variables_overwrite.get(self.scope, {}): - comp.left = self.variables_overwrite[self.scope][comp.left.id] # type:ignore[assignment] - if isinstance(comp.left, ast.NamedExpr): - self.variables_overwrite[self.scope][comp.left.target.id] = comp.left # type:ignore[assignment] + match comp.left: + case ast.Name(id=name_id) if name_id in self.variables_overwrite.get( + self.scope, {} + ): + comp.left = self.variables_overwrite[self.scope][name_id] # type: ignore[assignment] + case ast.NamedExpr(target=ast.Name(id=target_id)): + self.variables_overwrite[self.scope][target_id] = comp.left # type: ignore[assignment] left_res, left_expl = self.visit(comp.left) if isinstance(comp.left, ast.Compare | ast.BoolOp): left_expl = f"({left_expl})" @@ -1136,13 +1128,14 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]: syms: list[ast.expr] = [] results = [left_res] for i, op, next_operand in it: - if ( - isinstance(next_operand, ast.NamedExpr) - and isinstance(left_res, ast.Name) - and next_operand.target.id == left_res.id - ): - next_operand.target.id = self.variable() - self.variables_overwrite[self.scope][left_res.id] = next_operand # type:ignore[assignment] + match (next_operand, left_res): + case ( + ast.NamedExpr(target=ast.Name(id=target_id)), + ast.Name(id=name_id), + ) if target_id == name_id: + next_operand.target.id = self.variable() + self.variables_overwrite[self.scope][name_id] = next_operand # type: ignore[assignment] + next_res, next_expl = self.visit(next_operand) if isinstance(next_operand, ast.Compare | ast.BoolOp): next_expl = f"({next_expl})" diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 18bc32dc86f..92664354470 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1552,7 +1552,9 @@ def test_simple_failure(): result.stdout.fnmatch_lines(["*E*assert (1 + 1) == 3"]) -class TestIssue10743: +class TestAssertionRewriteWalrusOperator: + """See #10743""" + def test_assertion_walrus_operator(self, pytester: Pytester) -> None: pytester.makepyfile( """ @@ -1719,6 +1721,22 @@ def test_walrus_operator_not_override_value(): result = pytester.runpytest() assert result.ret == 0 + def test_assertion_namedexpr_compare_left_overwrite( + self, pytester: Pytester + ) -> None: + pytester.makepyfile( + """ + def test_namedexpr_compare_left_overwrite(): + a = "Hello" + b = "World" + c = "Test" + assert (a := b) == c and (a := "Test") == "Test" + """ + ) + result = pytester.runpytest() + assert result.ret == 1 + result.stdout.fnmatch_lines(["*assert ('World' == 'Test'*"]) + class TestIssue11028: def test_assertion_walrus_operator_in_operand(self, pytester: Pytester) -> None: