Skip to content

Commit

Permalink
[SOT] Dont use get_py_value from Tensor default in MAKE_FUNCTION
Browse files Browse the repository at this point in the history
…to avoid breakgraph (#71048)
  • Loading branch information
SigureMo authored Feb 8, 2025
1 parent 0f78ead commit d71159b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -1509,7 +1509,7 @@ def attach_new_attribute(self, flag, related_list):
kw_default_args_variable = self.stack.pop()
assert isinstance(kw_default_args_variable, DictVariable)
related_list.append(kw_default_args_variable)
kw_defaults = kw_default_args_variable.get_py_value()
kw_defaults = kw_default_args_variable.get_wrapped_items()

if flag & MF.MF_HAS_DEFAULTS:
'''
Expand Down
22 changes: 22 additions & 0 deletions test/sot/test_13_make_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,24 @@ def fn(a: int = 1, b: float = 2.0, /, *, c: int = 4, d: float = 5):
return fn(2, 3, c=1, d=2.0) + x


def make_fn_tensor_default(x: paddle.Tensor):
tensor = paddle.to_tensor(1.0)

def fn(a, b, c=tensor):
return a + b + c

return fn(1, 2) + x


def make_fn_tensor_kwdefault(x: paddle.Tensor):
tensor = paddle.to_tensor(1.0)

def fn(*args, c=tensor):
return args[0] + args[1] + c

return fn(1, 2, c=3) + x


class TestMakeFunction(TestCaseBase):
def test_simple(self):
self.assert_results(make_fn_simple, paddle.to_tensor(1))
Expand All @@ -76,6 +94,10 @@ def test_simple(self):
self.assert_results(make_fn_closure, paddle.to_tensor(1))
self.assert_results(make_fn_mix, paddle.to_tensor(1))

def test_tensor_default(self):
self.assert_results(make_fn_tensor_default, paddle.to_tensor(1))
self.assert_results(make_fn_tensor_kwdefault, paddle.to_tensor(1))


if __name__ == "__main__":
unittest.main()

0 comments on commit d71159b

Please sign in to comment.