Separate forward and backward compilation for default partition
ghstack-source-id: c24ee1b8c252d9aebe99b0beb9139dd3eb223dd4
Pull Request resolved: #856
anjali411 committed Jun 9, 2022
1 parent 130582c commit 4d7096b
Showing 2 changed files with 107 additions and 25 deletions.
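
For orientation, a minimal usage sketch of the code path this commit touches: compiling a function with aot_function and the default partitioner, where, after this change, the backward graph is traced and compiled lazily on the first backward() call instead of together with the forward. The import path and keyword names (aot_function, nop, default_partition from functorch.compile) are assumed from the functorch API of this period, not taken from the diff below.

import torch
from functorch.compile import aot_function, default_partition, nop  # assumed import path

def fn(a, b):
    return (a * b).sin()

a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)

# nop is the identity compiler; with default_partition, this commit defers
# tracing/compiling the backward graph until backward() is first called.
compiled_fn = aot_function(fn, fw_compiler=nop, bw_compiler=nop,
                           partition_fn=default_partition)
compiled_fn(a, b).sum().backward()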
91 changes: 72 additions & 19 deletions functorch/_src/aot_autograd.py
@@ -54,7 +54,7 @@ def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
 
 def create_joint_forward_backward(fn):
     def joint_forward_backward(
-        primals: List[Any], tangents: List[Any]
+        primals: List[Any], cotangents: List[Any]
     ) -> Tuple[List[Any], List[Any]]:
         # Call the forward pass
         outs = fn(*primals)
@@ -68,21 +68,21 @@ def joint_forward_backward(
                 grad_primals.append(p)
 
         # Get the outputs that need gradients
-        assert len(tangents) == len(outs)
+        assert len(cotangents) == len(outs)
         needed_outs = []
-        needed_tangents = []
-        for out, tangent in zip(outs, tangents):
+        needed_cotangents = []
+        for out, cotangent in zip(outs, cotangents):
             if isinstance(out, Tensor) and out.requires_grad:
                 needed_outs.append(out)
-                needed_tangents.append(tangent)
+                needed_cotangents.append(cotangent)
         backward_out = []
         # Call the backwards pass
         if grad_primals:
             backward_out = torch.autograd.grad(
                 needed_outs,
                 grad_primals,
-                grad_outputs=needed_tangents,
-                allow_unused=True,
+                grad_outputs=needed_cotangents,
+                allow_unused=True
             )
         backward_out_iter = iter(backward_out)
         return outs, [
@@ -140,12 +140,13 @@ def create_aot_autograd_function(
     compiled_fw = None
     compiled_bw = None
     num_outs = None
+    aot_decompositions = {**aot_autograd_decompositions, **decompositions}
 
     class CompiledFunction(torch.autograd.Function):
         @staticmethod
         @disable_torchdynamo
         def forward(ctx, *flat_tensor_args):
-            nonlocal compiled_fw, compiled_bw, num_outs
+            nonlocal compiled_fw, num_outs
             if compiled_fw is None:
                 with torch.set_grad_enabled(grad_state):
                     out = flat_fn(*flat_tensor_args)
@@ -159,31 +160,83 @@ def forward(ctx, *flat_tensor_args):
                     num_outs = 1
 
                 joint_inputs = (flat_tensor_args, out)
-                aot_decompositions = {**aot_autograd_decompositions, **decompositions}
                 # Need it because autograd.Function disables grad in forward
                 with torch.set_grad_enabled(grad_state):
                     fx_g = make_fx(joint_forward_backward, aot_decompositions)(
                         *joint_inputs
                     )
                 fw_module, bw_module = partition_fn(fx_g, joint_inputs)
-                # print(fw_module.code, bw_module.code)
 
                 compiled_fw = fw_compiler(fw_module, flat_tensor_args)
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
 
-                bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
-                compiled_bw = bw_compiler(bw_module, bw_args)
+                if partition_fn is default_partition:
+                    print("ENTERING default_partition")
+                    ctx.num_intermediate = len(fw_outs[num_outs:])
+                    ctx.num_inputs = len(flat_tensor_args)
+                    to_be_saved = fw_outs[num_outs:] + list(flat_tensor_args) + out
+                    print("fw outs: ", fw_outs, "-------")
+                    ctx.save_for_backward(*to_be_saved)
+                    ctx.fwd_graph = fw_module.code
+                else:
+                    nonlocal compiled_bw
+                    bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
+                    compiled_bw = bw_compiler(bw_module, bw_args)
+                    ctx.save_for_backward(*fw_outs[num_outs:])
             else:
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
-            ctx.save_for_backward(*fw_outs[num_outs:])
+                if partition_fn is default_partition:
+                    with torch.set_grad_enabled(grad_state):
+                        out = flat_fn(*flat_tensor_args)
+                    out = pytree.tree_map(
+                        lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x, out
+                    )
+                    ctx.num_intermediate = len(fw_outs[num_outs:])
+                    ctx.num_inputs = len(flat_tensor_args)
+                    to_be_saved = fw_outs[num_outs:] + list(flat_tensor_args) + out
+                    ctx.save_for_backward(*to_be_saved)
+                else:
+                    ctx.save_for_backward(*fw_outs[num_outs:])
             return tuple(fw_outs[0:num_outs])
 
         @staticmethod
         @disable_torchdynamo
-        def backward(ctx, *flat_args):
-            contiguous_args = [t.contiguous() for t in flat_args]
-            # contiguous_args = [t for t in flat_args]
-            out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
-            return tuple(out)
+        def backward(ctx, *flat_grad_outs):
+            print(flat_grad_outs)
+            contiguous_args = [t.contiguous() for t in flat_grad_outs]
+            if compiled_bw is None:
+                assert partition_fn is default_partition
+                with torch.set_grad_enabled(grad_state):
+                    inputs = ctx.saved_tensors[ctx.num_intermediate:ctx.num_intermediate+ctx.num_inputs]
+                    fx_g = make_fx(joint_forward_backward, aot_decompositions)(inputs, contiguous_args)
+                    # assert that the forward graph generated here is the same
+                    # if it's specified that the user might want to calculate double backwards
+                    fw_module, bw_module = partition_fn(fx_g, ctx.saved_tensors[ctx.num_intermediate:])
+                    print(fw_module.code)
+                    print(ctx.fwd_graph)
+                    assert fw_module.code == ctx.fwd_graph
+                func_code = bw_module.code.split('self, ')
+                # print(func_code[0] + func_code[1])
+                exec(func_code[0] + func_code[1], globals())
+                f = create_aot_autograd_function(forward, bw_compiler, bw_compiler, partition_fn, aot_decompositions, grad_state)
+                # print(bw_module.code, *ctx.saved_tensors, contiguous_args)
+                # print(*ctx.saved_tensors[:ctx.num_intermediate], *contiguous_args)
+                # print(*ctx.saved_tensors[ctx.num_intermediate:], *contiguous_args)
+                return f.apply(*ctx.saved_tensors[:ctx.num_intermediate], *contiguous_args)
+            else:
+                assert not torch.is_grad_enabled()
+                out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
+                return tuple(out)
+            # nonlocal compiled_bw
+            # contiguous_args = [t.contiguous() for t in flat_grad_outs]
+            # if compiled_bw is None:
+            #     with torch.set_grad_enabled(grad_state):
+            #         fx_g = make_fx(joint_forward_backward, aot_decompositions)(joint_inputs[0], contiguous_args)
+            #         # assert that the forward graph generated here is the same
+            #         # if it's specified that the user might want to calculate double backwards
+            #         fw_module, bw_module = partition_fn(fx_g, joint_inputs)
+            #     compiled_bw = bw_compiler(bw_module, fw_outs[num_outs:] + contiguous_args)
+            #     out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
+            #     return tuple(out)
 
     return CompiledFunction
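
As an aside (not part of the commit), the joint forward/backward pattern that create_joint_forward_backward traces above reduces to a function of primals and cotangents: run the forward, then call torch.autograd.grad with the cotangents as grad_outputs to obtain the input gradients. A self-contained illustration with a made-up two-input function:

import torch

def joint_fb(primals, cotangents):
    a, b = primals
    out = (a * b).sin()              # forward pass
    grads = torch.autograd.grad(     # backward pass: VJP against the cotangents
        out, (a, b), grad_outputs=cotangents[0], allow_unused=True
    )
    return (out,), grads

a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
outs, grads = joint_fb((a, b), (torch.ones(3),))  # grads w.r.t. a and b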

41 changes: 35 additions & 6 deletions test/test_pythonkey.py
@@ -246,25 +246,54 @@ def f(args, kwargs):
 
 def _outs_and_grads(fn, inps):
     outs = fn(*inps)
+    diff_outs = []
     for out in pytree.tree_flatten(outs)[0]:
         if isinstance(out, torch.Tensor) and out.requires_grad:
-            out.sum().backward(retain_graph=True)
-    grads = [inp.grad for inp in pytree.tree_flatten(inps)[0]]
-    for inp in pytree.tree_flatten(inps)[0]:
-        inp.grad = None
+            diff_outs.append(out)
+    def full_reduce(outs):
+        res = 0
+        for out in outs:
+            res=res+out.sum()
+        return res
+    print(inps)
+    grads = torch.autograd.grad(full_reduce(diff_outs), pytree.tree_flatten(inps)[0], create_graph=True)
     return outs, grads
 
+def _outs_and_grads_and_grad_grads(fn, inps):
+    outs = fn(*inps)
+    diff_outs = []
+    diff_inps = []
+    for out in pytree.tree_flatten(outs)[0]:
+        if isinstance(out, torch.Tensor) and out.requires_grad:
+            diff_outs.append(out)
+    for inp in pytree.tree_flatten(inps)[0]:
+        if isinstance(inp, torch.Tensor) and inp.requires_grad:
+            diff_inps.append(inp)
+    def full_reduce(outs):
+        res = 0
+        for out in outs:
+            res=res+out.sum()
+        return res
+    grads = torch.autograd.grad(full_reduce(diff_outs), diff_inps, create_graph=True)
+    print("grads: ", grads)
+    diff_grads = []
+    for grad_ in grads:
+        if isinstance(grad_, torch.Tensor) and grad_.requires_grad:
+            diff_grads.append(grad_)
+    grad_grads = torch.autograd.grad(full_reduce(diff_grads), diff_inps)
+    return outs, grads, grad_grads
 
 class TestAOTAutograd(TestCase):
     def verify_aot_autograd(self, f, inp):
         if isinstance(f, nn.Module):
             compiled_f = aot_module(f, nop)
         else:
             compiled_f = aot_function(f, nop)
-        ref_out, ref_grad = _outs_and_grads(f, inp)
-        test_out, test_grad = _outs_and_grads(compiled_f, inp)
+        ref_out, ref_grad, ref_grad_grad = _outs_and_grads_and_grad_grads(f, inp)
+        test_out, test_grad, test_grad_grad = _outs_and_grads_and_grad_grads(compiled_f, inp)
         self.assertEqual(ref_out, test_out)
         self.assertEqual(ref_grad, test_grad)
+        # self.assertEqual(ref_grad_grad, test_grad_grad)
 
     def test_single_output(self):
         def f(a, b):
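For reference (not from the test file), the double-backward pattern the new helpers rely on: take first-order gradients with create_graph=True so that a second torch.autograd.grad call can differentiate through them.

import torch

x = torch.randn(4, requires_grad=True)
y = (x ** 3).sum()
(gx,) = torch.autograd.grad(y, x, create_graph=True)  # gx = 3 * x**2, itself differentiable
(ggx,) = torch.autograd.grad(gx.sum(), x)             # ggx = 6 * x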
