Separate forward and backward compilation and support higher order derivatives for aot_function #856

Status: Open. Wants to merge 12 commits into base: gh/anjali411/1/base.
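For context, a minimal sketch (not taken from this PR) of the higher-order-derivative use case the change targets. It assumes functorch's aot_function API and its pass-through nop compiler; the function f and the tensor shapes are illustrative only.

import torch
from functorch.compile import aot_function, nop

def f(x):
    return torch.sin(x).sum()

# Compile with pass-through compilers so only the aot_autograd machinery runs.
aot_f = aot_function(f, fw_compiler=nop, bw_compiler=nop)

x = torch.randn(4, requires_grad=True)
out = aot_f(x)
# First-order gradient; create_graph=True keeps the graph for double backward.
(grad_x,) = torch.autograd.grad(out, x, create_graph=True)
# Second-order gradient, the case this PR aims to support.
(grad2_x,) = torch.autograd.grad(grad_x.sum(), x)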
114 changes: 83 additions & 31 deletions functorch/_src/aot_autograd.py
@@ -11,7 +11,7 @@
from functorch.experimental import functionalize
from . import config
from .decompositions import register_decomposition
from .partitioners import default_partition
from .partitioners import default_partition, _get_saved_values, _extract_fwd_bwd_modules
from .named_members_polyfill import _named_parameters, _named_buffers
from typing import Callable, List, Dict, Any, Tuple, Optional
from functools import wraps
@@ -70,7 +70,7 @@ def preserve_rng_state():

def create_joint_forward_backward(fn):
def joint_forward_backward(
primals: List[Any], tangents: List[Any]
primals: List[Any], cotangents: List[Any]
) -> Tuple[List[Any], List[Any]]:
# Call the forward pass
outs = fn(*primals)
@@ -84,21 +84,21 @@ def joint_forward_backward(
grad_primals.append(p)

# Get the outputs that need gradients
assert len(tangents) == len(outs)
assert len(cotangents) == len(outs)
needed_outs = []
needed_tangents = []
for out, tangent in zip(outs, tangents):
needed_cotangents = []
for out, cotangent in zip(outs, cotangents):
if isinstance(out, Tensor) and out.requires_grad:
needed_outs.append(out)
needed_tangents.append(tangent)
needed_cotangents.append(cotangent)
backward_out = []
# Call the backwards pass
if grad_primals:
backward_out = torch.autograd.grad(
needed_outs,
grad_primals,
grad_outputs=needed_tangents,
allow_unused=True,
grad_outputs=needed_cotangents,
allow_unused=True
)
backward_out_iter = iter(backward_out)
return outs, [
@@ -152,22 +152,31 @@ def create_aot_autograd_function(
if decompositions is None:
decompositions = {}
joint_forward_backward = create_joint_forward_backward(flat_fn)

# create_joint_forward_backward takes primals and cotangents as inputs
# primals: the flat inputs, cotangents: flat_grad_outs
j_b = None
compiled_fw = None
compiled_bw = None
bw_modules = []
num_outs = None
saved_value_names = None
aot_decompositions = {**aot_autograd_decompositions, **decompositions}

class CompiledFunction(torch.autograd.Function):
@staticmethod
@disable_torchdynamo
def forward(ctx, *flat_tensor_args):
nonlocal compiled_fw, compiled_bw, num_outs
# ctx.set_materialize_grads(False)
nonlocal compiled_fw, num_outs, saved_value_names, aot_decompositions, j_b
# Disable the JIT Autocast flag to prevent re-autocasting of jitted graph.
# TODO - Remove when https://github.com/pytorch/functorch/pull/794 is fixed.
old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
# Save a reference to the original inputs, since the inputs might be returned as outs
# and would then incorrectly have grad_fn set on them.
flat_tensor_args_0 = flat_tensor_args
if compiled_fw is None:
with preserve_rng_state():
# Set input tensors that require grad to leaves
# Detach to not accidentally extend the graph
flat_tensor_args = pytree.tree_map(
lambda x: x.detach().requires_grad_(x.requires_grad)
if isinstance(x, Tensor) else x, flat_tensor_args
@@ -184,8 +193,9 @@ def forward(ctx, *flat_tensor_args):
num_outs = 1

joint_inputs = (flat_tensor_args, out)
aot_decompositions = {**aot_autograd_decompositions, **decompositions}
with torch.set_grad_enabled(grad_state):
# The forward and backward graphs are created based on the input fn.
# However, we also need to take in grad_out for the saved intermediates.
fx_g = make_fx(joint_forward_backward, aot_decompositions)(
*joint_inputs
)
@@ -196,33 +206,76 @@ def fake_fn(primals, tangents):
def fake_fn(primals, tangents):
return fx_g(primals, tangents)
fx_g = make_fx(functionalize(fake_fn))(*joint_inputs)
fw_module, bw_module = partition_fn(fx_g, joint_inputs)
# print(fw_module.code, bw_module.code)

fw_module, bw_module, saved_value_nodes = partition_fn(fx_g, joint_inputs)
saved_value_names = [node.name for node in saved_value_nodes]
compiled_fw = fw_compiler(fw_module, flat_tensor_args)
fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))

bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
compiled_bw = bw_compiler(bw_module, bw_args)
j_b = create_joint_forward_backward(fw_module)
else:
fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
ctx.num_intermediate = len(fw_outs[num_outs:])
to_be_saved = fw_outs[num_outs:] + list(flat_tensor_args_0)
ctx.save_for_backward(*to_be_saved)
torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
ctx.save_for_backward(*fw_outs[num_outs:])
return tuple(fw_outs[0:num_outs])
return tuple(fw_outs)

@staticmethod
@disable_torchdynamo
def backward(ctx, *flat_args):
def backward(ctx, *flat_grad_outs):
# Disable the JIT Autocast flag to prevent re-autocasting of jitted graph.
# TODO - Remove when https://github.com/pytorch/functorch/pull/794 is fixed.
old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
contiguous_args = [t.contiguous() for t in flat_args]
# contiguous_args = [t for t in flat_args]
out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
nonlocal bw_modules, saved_value_names, num_outs, aot_decompositions, j_b
with preserve_rng_state():
intermediates = ctx.saved_tensors[:ctx.num_intermediate]
flat_tensor_args = ctx.saved_tensors[ctx.num_intermediate:]
flat_tensor_args = pytree.tree_map(
lambda x: x.detach().requires_grad_(x.requires_grad)
if isinstance(x, Tensor) else x, flat_tensor_args
)
inp_grad_outs = flat_grad_outs
with torch.set_grad_enabled(grad_state):
fx_g_b = make_fx(j_b, aot_decompositions)(flat_tensor_args, inp_grad_outs)
if config.use_functionalize:
# Functionalize the forward backward graph. First create a
# fake fn to make functionalize happy
def fake_fn(primals, tangents):
return fx_g_b(primals, tangents)
fx_g_b = make_fx(functionalize(fake_fn))(flat_tensor_args, inp_grad_outs)
saved_value_nodes = _get_saved_values(fx_g_b, saved_value_names)
Author review comment on the line above:
Unfortunately this approach doesn't always work because the newly generated fx graph may not have the same nodes as the previous graph. We need an alternate way to select nodes of interest in this new graph!
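
A minimal, hypothetical sketch (not part of the PR) of the name-matching idea and why it is fragile, assuming functorch's make_fx; the function f and the names involved are illustrative only.

import torch
from functorch import make_fx

def f(x, y):
    a = x * y
    return a.sum()

# Trace once and record the names of the nodes we want to "save".
g1 = make_fx(f)(torch.randn(3), torch.randn(3))
saved_value_names = [n.name for n in g1.graph.nodes if n.op == 'call_function']

# Retracing (possibly under functionalize or with different decompositions)
# yields a second graph; selecting nodes by the recorded names only works if
# the node names happen to line up between the two traces.
g2 = make_fx(f)(torch.randn(3), torch.randn(3))
matched = [n for n in g2.graph.nodes if n.name in saved_value_names]
assert len(matched) <= len(saved_value_names)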

assert len(saved_value_nodes) <= len(saved_value_names)
fw_module_b, bw_module_b, saved_values_new = _extract_fwd_bwd_modules(fx_g_b, saved_value_nodes)
if len(saved_values_new) != len(saved_value_names):
new_intermediates = []
# Forward saves more intermediates than needed
assert len(saved_values_new) < len(saved_value_names)
j = 0
for node in saved_values_new:
while node.name != saved_value_names[j]:
j += 1
new_intermediates.append(intermediates[j])
j += 1
intermediates = new_intermediates

# This is needed because aot_function caching uses the function id right now
bw_module_fn = None
for elem in bw_modules:
if elem.code == bw_module_b.code:
bw_module_fn = elem
break
if bw_module_fn is None:
bw_modules.append(bw_module_b)
bw_module_fn = bw_module_b

f = aot_function(bw_module_fn, bw_compiler, bw_compiler, partition_fn, aot_decompositions)
out = f(*intermediates, *inp_grad_outs)
torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
return tuple(out)
return tuple(normalize_as_list(out))

return CompiledFunction
def return_fn(*args, **kwargs):
out = CompiledFunction.apply(*args, **kwargs)
return out[0:num_outs]
return return_fn


class _CompileCache(CompileCache):
@@ -312,7 +365,7 @@ def rearrange(tensor_args, static_args, static_argnums):
return args


KNOWN_TYPES = [torch.Tensor, int, str, float, bool]
KNOWN_TYPES = [torch.Tensor, int, str, float, bool, None]


def aot_function(
@@ -448,7 +501,6 @@ def returned_function(*args, **kwargs):
hasher_type,
*flat_args_for_cache,
)

# Compile the function and save it in the cache
if cached_res is None:
# Save the args_spec for flat_tensor_args to unflatten while tracing
@@ -473,7 +525,7 @@ def flat_fn(*flat_tensor_args):
for i in flat_out:
is_known_type = False
for j in KNOWN_TYPES:
if isinstance(i, j):
if j is None or isinstance(i, j):
is_known_type = True
break
if not is_known_type:
Expand All @@ -495,7 +547,7 @@ def flat_fn(*flat_tensor_args):
partition_fn,
decompositions,
grad_state=torch.is_grad_enabled(),
).apply
)
cached_res = (compiled_fn, out_spec)

# Save the compiled_fn in the cache
@@ -635,7 +687,7 @@ def aot_function_simplified(
partition_fn,
decompositions,
grad_state=torch.is_grad_enabled(),
).apply
)

return compiled_fn

21 changes: 19 additions & 2 deletions functorch/_src/partitioners.py
@@ -109,7 +109,24 @@ def _extract_fwd_bwd_modules(joint_module: fx.GraphModule, saved_values):

fwd_module = fx.GraphModule(joint_module, fwd_graph)
bwd_module = fx.GraphModule(joint_module, bwd_graph)
return fwd_module, bwd_module
return fwd_module, bwd_module, saved_values


def _get_saved_values(new_module: fx.GraphModule, saved_value_names):
saved_values = []
for node in new_module.graph.nodes:
if node.name in saved_value_names:
if 'tensor_meta' not in node.meta and node.op == 'call_function':
users = node.users
assert all(user.target == operator.getitem for user in users)
for user in users:
saved_values.append(user)
else:
saved_values.append(node)

saved_values = list(saved_values)

return saved_values


def default_partition(
@@ -154,8 +171,8 @@ def default_partition(
saved_values.append(user)
else:
saved_values.append(node)
saved_values = list(set(saved_values))

saved_values = list(saved_values)
return _extract_fwd_bwd_modules(joint_module, saved_values)

