diff --git a/functorch/_src/aot_autograd.py b/functorch/_src/aot_autograd.py
index 57a7ac68f..c4809b7c7 100644
--- a/functorch/_src/aot_autograd.py
+++ b/functorch/_src/aot_autograd.py
@@ -54,7 +54,7 @@ def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
 
 def create_joint_forward_backward(fn):
     def joint_forward_backward(
-        primals: List[Any], tangents: List[Any]
+        primals: List[Any], cotangents: List[Any]
     ) -> Tuple[List[Any], List[Any]]:
         # Call the forward pass
         outs = fn(*primals)
@@ -68,21 +68,21 @@ def joint_forward_backward(
                 grad_primals.append(p)
 
         # Get the outputs that need gradients
-        assert len(tangents) == len(outs)
+        assert len(cotangents) == len(outs)
         needed_outs = []
-        needed_tangents = []
-        for out, tangent in zip(outs, tangents):
+        needed_cotangents = []
+        for out, cotangent in zip(outs, cotangents):
             if isinstance(out, Tensor) and out.requires_grad:
                 needed_outs.append(out)
-                needed_tangents.append(tangent)
+                needed_cotangents.append(cotangent)
         backward_out = []
         # Call the backwards pass
         if grad_primals:
             backward_out = torch.autograd.grad(
                 needed_outs,
                 grad_primals,
-                grad_outputs=needed_tangents,
-                allow_unused=True,
+                grad_outputs=needed_cotangents,
+                allow_unused=True
             )
         backward_out_iter = iter(backward_out)
         return outs, [
@@ -140,12 +140,13 @@ def create_aot_autograd_function(
     compiled_fw = None
     compiled_bw = None
     num_outs = None
+    aot_decompositions = {**aot_autograd_decompositions, **decompositions}
 
     class CompiledFunction(torch.autograd.Function):
         @staticmethod
         @disable_torchdynamo
         def forward(ctx, *flat_tensor_args):
-            nonlocal compiled_fw, compiled_bw, num_outs
+            nonlocal compiled_fw, num_outs
             if compiled_fw is None:
                 with torch.set_grad_enabled(grad_state):
                     out = flat_fn(*flat_tensor_args)
@@ -159,31 +160,83 @@ def forward(ctx, *flat_tensor_args):
                         num_outs = 1
 
                 joint_inputs = (flat_tensor_args, out)
-                aot_decompositions = {**aot_autograd_decompositions, **decompositions}
+                # Need it because autograd.Function disables grad in forward
                 with torch.set_grad_enabled(grad_state):
                     fx_g = make_fx(joint_forward_backward, aot_decompositions)(
                         *joint_inputs
                     )
                 fw_module, bw_module = partition_fn(fx_g, joint_inputs)
-                # print(fw_module.code, bw_module.code)
 
                 compiled_fw = fw_compiler(fw_module, flat_tensor_args)
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
-
-                bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
-                compiled_bw = bw_compiler(bw_module, bw_args)
+                if partition_fn is default_partition:
+                    print("ENTERING default_partition")
+                    ctx.num_intermediate = len(fw_outs[num_outs:])
+                    ctx.num_inputs = len(flat_tensor_args)
+                    to_be_saved = fw_outs[num_outs:] + list(flat_tensor_args) + out
+                    print("fw outs: ", fw_outs, "-------")
+                    ctx.save_for_backward(*to_be_saved)
+                    ctx.fwd_graph = fw_module.code
+                else:
+                    nonlocal compiled_bw
+                    bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
+                    compiled_bw = bw_compiler(bw_module, bw_args)
+                    ctx.save_for_backward(*fw_outs[num_outs:])
             else:
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
-            ctx.save_for_backward(*fw_outs[num_outs:])
+                if partition_fn is default_partition:
+                    with torch.set_grad_enabled(grad_state):
+                        out = flat_fn(*flat_tensor_args)
+                        out = pytree.tree_map(
+                            lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x, out
+                        )
+                    ctx.num_intermediate = len(fw_outs[num_outs:])
+                    ctx.num_inputs = len(flat_tensor_args)
+                    to_be_saved = fw_outs[num_outs:] + list(flat_tensor_args) + out
+                    ctx.save_for_backward(*to_be_saved)
+                else:
+                    ctx.save_for_backward(*fw_outs[num_outs:])
             return tuple(fw_outs[0:num_outs])
 
         @staticmethod
         @disable_torchdynamo
-        def backward(ctx, *flat_args):
-            contiguous_args = [t.contiguous() for t in flat_args]
-            # contiguous_args = [t for t in flat_args]
-            out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
-            return tuple(out)
+        def backward(ctx, *flat_grad_outs):
+            print(flat_grad_outs)
+            contiguous_args = [t.contiguous() for t in flat_grad_outs]
+            if compiled_bw is None:
+                assert partition_fn is default_partition
+                with torch.set_grad_enabled(grad_state):
+                    inputs = ctx.saved_tensors[ctx.num_intermediate:ctx.num_intermediate+ctx.num_inputs]
+                    fx_g = make_fx(joint_forward_backward, aot_decompositions)(inputs, contiguous_args)
+                # assert that the forward graph generated here is the same
+                # if it's specified that the user might want to calculate double backwards
+                fw_module, bw_module = partition_fn(fx_g, ctx.saved_tensors[ctx.num_intermediate:])
+                print(fw_module.code)
+                print(ctx.fwd_graph)
+                assert fw_module.code == ctx.fwd_graph
+                func_code = bw_module.code.split('self, ')
+                # print(func_code[0] + func_code[1])
+                exec(func_code[0] + func_code[1], globals())
+                f = create_aot_autograd_function(forward, bw_compiler, bw_compiler, partition_fn, aot_decompositions, grad_state)
+                # print(bw_module.code, *ctx.saved_tensors, contiguous_args)
+                # print(*ctx.saved_tensors[:ctx.num_intermediate], *contiguous_args)
+                # print(*ctx.saved_tensors[ctx.num_intermediate:], *contiguous_args)
+                return f.apply(*ctx.saved_tensors[:ctx.num_intermediate], *contiguous_args)
+            else:
+                assert not torch.is_grad_enabled()
+                out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
+                return tuple(out)
+            # nonlocal compiled_bw
+            # contiguous_args = [t.contiguous() for t in flat_grad_outs]
+            # if compiled_bw is None:
+            #     with torch.set_grad_enabled(grad_state):
+            #         fx_g = make_fx(joint_forward_backward, aot_decompositions)(joint_inputs[0], contiguous_args)
+            #     # assert that the forward graph generated here is the same
+            #     # if it's specified that the user might want to calculate double backwards
+            #     fw_module, bw_module = partition_fn(fx_g, joint_inputs)
+            #     compiled_bw = bw_compiler(bw_module, fw_outs[num_outs:] + contiguous_args)
+            #     out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
+            #     return tuple(out)
 
     return CompiledFunction
diff --git a/test/test_pythonkey.py b/test/test_pythonkey.py
index ae399fc81..faf0a55de 100644
--- a/test/test_pythonkey.py
+++ b/test/test_pythonkey.py
@@ -246,14 +246,42 @@ def f(args, kwargs):
 
 def _outs_and_grads(fn, inps):
     outs = fn(*inps)
+    diff_outs = []
     for out in pytree.tree_flatten(outs)[0]:
         if isinstance(out, torch.Tensor) and out.requires_grad:
-            out.sum().backward(retain_graph=True)
-    grads = [inp.grad for inp in pytree.tree_flatten(inps)[0]]
-    for inp in pytree.tree_flatten(inps)[0]:
-        inp.grad = None
+            diff_outs.append(out)
+    def full_reduce(outs):
+        res = 0
+        for out in outs:
+            res = res + out.sum()
+        return res
+    print(inps)
+    grads = torch.autograd.grad(full_reduce(diff_outs), pytree.tree_flatten(inps)[0], create_graph=True)
     return outs, grads
 
+def _outs_and_grads_and_grad_grads(fn, inps):
+    outs = fn(*inps)
+    diff_outs = []
+    diff_inps = []
+    for out in pytree.tree_flatten(outs)[0]:
+        if isinstance(out, torch.Tensor) and out.requires_grad:
+            diff_outs.append(out)
+    for inp in pytree.tree_flatten(inps)[0]:
+        if isinstance(inp, torch.Tensor) and inp.requires_grad:
+            diff_inps.append(inp)
+    def full_reduce(outs):
+        res = 0
+        for out in outs:
+            res = res + out.sum()
+        return res
+    grads = torch.autograd.grad(full_reduce(diff_outs), diff_inps, create_graph=True)
+    print("grads: ", grads)
+    diff_grads = []
+    for grad_ in grads:
+        if isinstance(grad_, torch.Tensor) and grad_.requires_grad:
+            diff_grads.append(grad_)
+    grad_grads = torch.autograd.grad(full_reduce(diff_grads), diff_inps)
+    return outs, grads, grad_grads
 
 class TestAOTAutograd(TestCase):
     def verify_aot_autograd(self, f, inp):
@@ -261,10 +289,11 @@ def verify_aot_autograd(self, f, inp):
         if isinstance(f, nn.Module):
             compiled_f = aot_module(f, nop)
         else:
             compiled_f = aot_function(f, nop)
-        ref_out, ref_grad = _outs_and_grads(f, inp)
-        test_out, test_grad = _outs_and_grads(compiled_f, inp)
+        ref_out, ref_grad, ref_grad_grad = _outs_and_grads_and_grad_grads(f, inp)
+        test_out, test_grad, test_grad_grad = _outs_and_grads_and_grad_grads(compiled_f, inp)
         self.assertEqual(ref_out, test_out)
         self.assertEqual(ref_grad, test_grad)
+        # self.assertEqual(ref_grad_grad, test_grad_grad)
 
     def test_single_output(self):
         def f(a, b):
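
Note (illustrative sketch, not part of the patch): the double-backward pattern that the new
_outs_and_grads_and_grad_grads test helper checks can be reproduced standalone with plain
torch.autograd.grad. The toy function and inputs below are made up for the sketch;
create_graph=True is what keeps the first-order grads differentiable so a second
torch.autograd.grad call works.

    import torch

    def f(a, b):
        # toy differentiable function; stands in for the function under test
        return (a * b).sin()

    a = torch.randn(3, requires_grad=True)
    b = torch.randn(3, requires_grad=True)

    out = f(a, b)
    # first-order grads, kept differentiable via create_graph=True
    grads = torch.autograd.grad(out.sum(), (a, b), create_graph=True)
    # second-order grads (grad of the grads) w.r.t. the same inputs
    grad_grads = torch.autograd.grad(sum(g.sum() for g in grads), (a, b))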