Skip to content

Commit

Permalink
Separate forward and backwad compilation
Browse files Browse the repository at this point in the history
ghstack-source-id: 0b78895219a89ec3841cf8b0804e1be69bfeed8a
Pull Request resolved: #856
  • Loading branch information
anjali411 committed Jul 13, 2022
1 parent 44dd1bf commit 55baed5
Show file tree
Hide file tree
Showing 4 changed files with 218 additions and 82 deletions.
114 changes: 83 additions & 31 deletions functorch/_src/aot_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from functorch.experimental import functionalize
from . import config
from .decompositions import register_decomposition
from .partitioners import default_partition
from .partitioners import default_partition, _get_saved_values, _extract_fwd_bwd_modules
from .named_members_polyfill import _named_parameters, _named_buffers
from typing import Callable, List, Dict, Any, Tuple, Optional
from functools import wraps
Expand Down Expand Up @@ -70,7 +70,7 @@ def preserve_rng_state():

def create_joint_forward_backward(fn):
def joint_forward_backward(
primals: List[Any], tangents: List[Any]
primals: List[Any], cotangents: List[Any]
) -> Tuple[List[Any], List[Any]]:
# Call the forward pass
outs = fn(*primals)
Expand All @@ -84,21 +84,21 @@ def joint_forward_backward(
grad_primals.append(p)

# Get the outputs that need gradients
assert len(tangents) == len(outs)
assert len(cotangents) == len(outs)
needed_outs = []
needed_tangents = []
for out, tangent in zip(outs, tangents):
needed_cotangents = []
for out, cotangent in zip(outs, cotangents):
if isinstance(out, Tensor) and out.requires_grad:
needed_outs.append(out)
needed_tangents.append(tangent)
needed_cotangents.append(cotangent)
backward_out = []
# Call the backwards pass
if grad_primals:
backward_out = torch.autograd.grad(
needed_outs,
grad_primals,
grad_outputs=needed_tangents,
allow_unused=True,
grad_outputs=needed_cotangents,
allow_unused=True
)
backward_out_iter = iter(backward_out)
return outs, [
Expand Down Expand Up @@ -152,22 +152,31 @@ def create_aot_autograd_function(
if decompositions is None:
decompositions = {}
joint_forward_backward = create_joint_forward_backward(flat_fn)

# create_joint_forward_backward takes inputs and cotangents as inps
# inps: inputs, cotangents: flat_grad_outs
j_b = None
compiled_fw = None
compiled_bw = None
bw_modules = []
num_outs = None
saved_value_names = None
aot_decompositions = {**aot_autograd_decompositions, **decompositions}

class CompiledFunction(torch.autograd.Function):
@staticmethod
@disable_torchdynamo
def forward(ctx, *flat_tensor_args):
nonlocal compiled_fw, compiled_bw, num_outs
# ctx.set_materialize_grads(False)
nonlocal compiled_fw, num_outs, saved_value_names, aot_decompositions, j_b
# Disable the JIT Autocast flag to prevent re-autocasting of jitted graph.
# TODO - Remove when https://github.com/pytorch/functorch/pull/794 is fixed.
old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
# creating this to save the original inputs since the inputs might be returned as outs
# and would then have grad_fn set on them which is incorrect.
flat_tensor_args_0 = flat_tensor_args
if compiled_fw is None:
with preserve_rng_state():
# Set input tensors that require grad to leaves
# Detach to not accidentally extend the graph
flat_tensor_args = pytree.tree_map(
lambda x: x.detach().requires_grad_(x.requires_grad)
if isinstance(x, Tensor) else x, flat_tensor_args
Expand All @@ -184,8 +193,9 @@ def forward(ctx, *flat_tensor_args):
num_outs = 1

joint_inputs = (flat_tensor_args, out)
aot_decompositions = {**aot_autograd_decompositions, **decompositions}
with torch.set_grad_enabled(grad_state):
# This means the forward and backward graphs are created based on the input fn
# However we need to take in grad_out for the saved intermediates as well.
fx_g = make_fx(joint_forward_backward, aot_decompositions)(
*joint_inputs
)
Expand All @@ -196,33 +206,76 @@ def forward(ctx, *flat_tensor_args):
def fake_fn(primals, tangents):
return fx_g(primals, tangents)
fx_g = make_fx(functionalize(fake_fn))(*joint_inputs)
fw_module, bw_module = partition_fn(fx_g, joint_inputs)
# print(fw_module.code, bw_module.code)

fw_module, bw_module, saved_value_nodes = partition_fn(fx_g, joint_inputs)
saved_value_names = [node.name for node in saved_value_nodes]
compiled_fw = fw_compiler(fw_module, flat_tensor_args)
fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))

bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
compiled_bw = bw_compiler(bw_module, bw_args)
j_b = create_joint_forward_backward(fw_module)
else:
fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
ctx.num_intermediate = len(fw_outs[num_outs:])
to_be_saved = fw_outs[num_outs:] + list(flat_tensor_args_0)
ctx.save_for_backward(*to_be_saved)
torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
ctx.save_for_backward(*fw_outs[num_outs:])
return tuple(fw_outs[0:num_outs])
return tuple(fw_outs)

@staticmethod
@disable_torchdynamo
def backward(ctx, *flat_args):
def backward(ctx, *flat_grad_outs):
# Disable the JIT Autocast flag to prevent re-autocasting of jitted graph.
# TODO - Remove when https://github.com/pytorch/functorch/pull/794 is fixed.
old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
contiguous_args = [t.contiguous() for t in flat_args]
# contiguous_args = [t for t in flat_args]
out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
nonlocal bw_modules, saved_value_names, num_outs, aot_decompositions, j_b
with preserve_rng_state():
intermediates = ctx.saved_tensors[:ctx.num_intermediate]
flat_tensor_args = ctx.saved_tensors[ctx.num_intermediate:]
flat_tensor_args = pytree.tree_map(
lambda x: x.detach().requires_grad_(x.requires_grad)
if isinstance(x, Tensor) else x, flat_tensor_args
)
inp_grad_outs = flat_grad_outs
with torch.set_grad_enabled(grad_state):
fx_g_b = make_fx(j_b, aot_decompositions)(flat_tensor_args, inp_grad_outs)
if config.use_functionalize:
# Functionalize the foward backward graph. First create a
# fake fn to make functionalize happy
def fake_fn(primals, tangents):
return fx_g_b(primals, tangents)
fx_g_b = make_fx(functionalize(fake_fn))(flat_tensor_args, inp_grad_outs)
saved_value_nodes = _get_saved_values(fx_g_b, saved_value_names)
assert len(saved_value_nodes) <= len(saved_value_names)
fw_module_b, bw_module_b, saved_values_new = _extract_fwd_bwd_modules(fx_g_b, saved_value_nodes)
if len(saved_values_new) != len(saved_value_names):
new_intermediates = []
# Forward saves more intermediates than needed
assert len(saved_values_new) < len(saved_value_names)
j = 0
for node in saved_values_new:
while node.name != saved_value_names[j]:
j += 1
new_intermediates.append(intermediates[j])
j += 1
intermediates = new_intermediates

# This is needed because aot function caching uses function id right now
bw_module_fn = None
for elem in bw_modules:
if elem.code == bw_module_b.code:
bw_module_fn = elem
break
if bw_module_fn is None:
bw_modules.append(bw_module_b)
bw_module_fn = bw_module_b

f = aot_function(bw_module_fn, bw_compiler, bw_compiler, partition_fn, aot_decompositions)
out = f(*intermediates, *inp_grad_outs)
torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
return tuple(out)
return tuple(normalize_as_list(out))

return CompiledFunction
def return_fn(*args, **kwargs):
out = CompiledFunction.apply(*args, **kwargs)
return out[0:num_outs]
return return_fn


class _CompileCache(CompileCache):
Expand Down Expand Up @@ -312,7 +365,7 @@ def rearrange(tensor_args, static_args, static_argnums):
return args


KNOWN_TYPES = [torch.Tensor, int, str, float, bool]
KNOWN_TYPES = [torch.Tensor, int, str, float, bool, None]


def aot_function(
Expand Down Expand Up @@ -448,7 +501,6 @@ def returned_function(*args, **kwargs):
hasher_type,
*flat_args_for_cache,
)

# Compile the function and save it in the cache
if cached_res is None:
# Save the args_spec for flat_tensor_args to unflatten while tracing
Expand All @@ -473,7 +525,7 @@ def flat_fn(*flat_tensor_args):
for i in flat_out:
is_known_type = False
for j in KNOWN_TYPES:
if isinstance(i, j):
if j is None or isinstance(i, j):
is_known_type = True
break
if not is_known_type:
Expand All @@ -495,7 +547,7 @@ def flat_fn(*flat_tensor_args):
partition_fn,
decompositions,
grad_state=torch.is_grad_enabled(),
).apply
)
cached_res = (compiled_fn, out_spec)

# Save the compiled_fn in the cache
Expand Down Expand Up @@ -635,7 +687,7 @@ def aot_function_simplified(
partition_fn,
decompositions,
grad_state=torch.is_grad_enabled(),
).apply
)

return compiled_fn

Expand Down
21 changes: 19 additions & 2 deletions functorch/_src/partitioners.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,24 @@ def _extract_fwd_bwd_modules(joint_module: fx.GraphModule, saved_values):

fwd_module = fx.GraphModule(joint_module, fwd_graph)
bwd_module = fx.GraphModule(joint_module, bwd_graph)
return fwd_module, bwd_module
return fwd_module, bwd_module, saved_values


def _get_saved_values(new_module: fx.GraphModule, saved_value_names):
saved_values = []
for node in new_module.graph.nodes:
if node.name in saved_value_names:
if 'tensor_meta' not in node.meta and node.op == 'call_function':
users = node.users
assert all(user.target == operator.getitem for user in users)
for user in users:
saved_values.append(user)
else:
saved_values.append(node)

saved_values = list(saved_values)

return saved_values


def default_partition(
Expand Down Expand Up @@ -154,8 +171,8 @@ def default_partition(
saved_values.append(user)
else:
saved_values.append(node)
saved_values = list(set(saved_values))

saved_values = list(saved_values)
return _extract_fwd_bwd_modules(joint_module, saved_values)


Expand Down
Loading

0 comments on commit 55baed5

Please sign in to comment.