Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Separate forward and backwad compilation and support higher order derivatives for aot_function #856

Open
wants to merge 12 commits into
base: gh/anjali411/1/base
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions functorch/_src/aot_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:


def create_joint_forward_backward(fn):
# tangents are just grad_outs/cotangents (wrong naming)
anjali411 marked this conversation as resolved.
Show resolved Hide resolved
def joint_forward_backward(
primals: List[Any], tangents: List[Any]
) -> Tuple[List[Any], List[Any]]:
Expand Down Expand Up @@ -140,12 +141,14 @@ def create_aot_autograd_function(
compiled_fw = None
compiled_bw = None
num_outs = None

joint_inputs = None
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need to save these tensors in the context

fw_outs = None
aot_decompositions = {**aot_autograd_decompositions, **decompositions}
class CompiledFunction(torch.autograd.Function):
@staticmethod
@disable_torchdynamo
def forward(ctx, *flat_tensor_args):
nonlocal compiled_fw, compiled_bw, num_outs
nonlocal compiled_fw, num_outs, joint_inputs, fw_outs
if compiled_fw is None:
with torch.set_grad_enabled(grad_state):
out = flat_fn(*flat_tensor_args)
Expand All @@ -159,29 +162,34 @@ def forward(ctx, *flat_tensor_args):
num_outs = 1

joint_inputs = (flat_tensor_args, out)
aot_decompositions = {**aot_autograd_decompositions, **decompositions}
# Need it because autograd.Function disables grad in forward
with torch.set_grad_enabled(grad_state):
fx_g = make_fx(joint_forward_backward, aot_decompositions)(
*joint_inputs
)
fw_module, bw_module = partition_fn(fx_g, joint_inputs)
# print(fw_module.code, bw_module.code)

compiled_fw = fw_compiler(fw_module, flat_tensor_args)
fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))

bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
compiled_bw = bw_compiler(bw_module, bw_args)
if partition_fn is default_partition:
nonlocal compiled_bw
bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
compiled_bw = bw_compiler(bw_module, bw_args)
else:
fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
ctx.save_for_backward(*fw_outs[num_outs:])
return tuple(fw_outs[0:num_outs])

@staticmethod
@disable_torchdynamo
def backward(ctx, *flat_args):
contiguous_args = [t.contiguous() for t in flat_args]
# contiguous_args = [t for t in flat_args]
def backward(ctx, *flat_grad_outs):
nonlocal compiled_bw
contiguous_args = [t.contiguous() for t in flat_grad_outs]
if compiled_bw is None:
with torch.set_grad_enabled(grad_state):
fx_g = make_fx(joint_forward_backward, aot_decompositions)(joint_inputs[0], contiguous_args)
fw_module, bw_module = partition_fn(fx_g, joint_inputs)
compiled_bw = bw_compiler(bw_module, fw_outs[num_outs:] + contiguous_args)
out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
return tuple(out)

Expand Down