Merge pull request #47 from EleutherAI/feature/42-pp_data_comm_func
[#42] pp data comm func (WIP)
hyunwoongko authored Sep 30, 2022
2 parents 7af33eb + 8ab949d commit 6296c0e
Showing 8 changed files with 604 additions and 400 deletions.
28 changes: 27 additions & 1 deletion oslo/torch/nn/parallel/pipeline_parallel/_buffers.py
@@ -1,3 +1,29 @@
+from oslo.torch.nn.parallel.pipeline_parallel._sync import (
+    register_location_for_forward_counter,
+)
+
+
+# original forward dictionary
+_ORIGINAL_FORWARDS = dict()
+
+# module device locations
+_MODULE_DEVICE_LOCATIONS = dict()
+
+
+def register_original_forward_function(location, func, device):
+    _ORIGINAL_FORWARDS[location] = func
+    _MODULE_DEVICE_LOCATIONS[location] = device
+    register_location_for_forward_counter(location)
+
+
+def get_original_forward_function(location):
+    return _ORIGINAL_FORWARDS[location]
+
+
+def get_module_device_location(location):
+    return _MODULE_DEVICE_LOCATIONS[location]
+
+
 # Activations
 _ACTIVATIONS = dict()
 
@@ -7,4 +33,4 @@ def save_activation(key, activation):
 
 
 def pop_activation(key):
-    return _ACTIVATIONS.pop(key)
+    return _ACTIVATIONS.pop(key, [])  # TODO; okay?
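For orientation, here is a minimal usage sketch (not part of the commit) showing how the helpers added in this file fit together on a single pipeline stage. The module, location string, request key, and device value below are made up for illustration.

import torch
import torch.nn as nn

from oslo.torch.nn.parallel.pipeline_parallel._buffers import (
    register_original_forward_function,
    get_original_forward_function,
    save_activation,
    pop_activation,
)

# hypothetical module owned by this pipeline stage
linear = nn.Linear(4, 4)

# remember the module's original forward and where it lives, keyed by a location string
register_original_forward_function("decoder.block.0.linear", linear.forward, "cpu")

# a remote forward request for that location looks the function up by key
forward_fn = get_original_forward_function("decoder.block.0.linear")
out = forward_fn(torch.randn(2, 4))

# stash outputs under a per-request key so a later backward can reuse them
save_activation("request-0", [out])
activation = pop_activation("request-0")  # pop falls back to [] for unknown keys

The location key is what lets remote_module_forward in _functional.py (below) recover the right forward function for an incoming RPC request.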
99 changes: 55 additions & 44 deletions oslo/torch/nn/parallel/pipeline_parallel/_functional.py
@@ -2,91 +2,102 @@
 from torch.cuda.amp import custom_fwd, custom_bwd
 from torch.distributed import rpc
 
-from oslo.torch.nn.parallel.pipeline_parallel._buffers import _ACTIVATIONS
-
-_FORWARD_MARKER = set()
-
-_LOCAL_BACKWARD_DONE = False
-
-_NUM_BACKWARD_DONE = 0
-
-
-def add_forward_marker(mark):
-    _FORWARD_MARKER.add(mark)
-
-
-def remove_forward_marker(mark):
-    _FORWARD_MARKER.remove(mark)
-
-
-def len_forward_marker():
-    return len(_FORWARD_MARKER)
-
-
-def increase_num_backward_done():
-    global _NUM_BACKWARD_DONE
-    _NUM_BACKWARD_DONE += 1
-
-
-def get_num_backward_done():
-    global _NUM_BACKWARD_DONE
-    return _NUM_BACKWARD_DONE
-
-
-def reset_num_backward_done():
-    global _NUM_BACKWARD_DONE, _LOCAL_BACKWARD_DONE
-    _NUM_BACKWARD_DONE = 0
-    _LOCAL_BACKWARD_DONE = False
+from oslo.torch.nn.parallel.pipeline_parallel._buffers import (
+    get_original_forward_function,
+    save_activation,
+    pop_activation,
+)
+from oslo.torch.nn.parallel.pipeline_parallel._sync import (
+    register_job_requires_backward,
+    notify_backward_job_done,
+)
+from oslo.torch.nn.parallel.pipeline_parallel._messages import (
+    pack_tensor_stub,
+    unpack_tensor_stub,
+)
+
+
+def remote_module_forward(
+    caller,
+    location,
+    unique_key,
+    args_stub,
+    kwargs_stub,
+    requires_redirection,
+    is_training,
+    is_grad_enabled,
+    *tensors
+):
+    if requires_redirection and is_training and is_grad_enabled:
+        # prepare backward redirection to caller
+        tensors = apply_backward_redirection(
+            caller,
+            unique_key,
+            *tensors,
+        )
+
+    (args, kwargs), _ = unpack_tensor_stub([args_stub, kwargs_stub], tensors)
+
+    forward_fn = get_original_forward_function(location)
+    with torch.set_grad_enabled(is_grad_enabled):
+        result = forward_fn(*args, **kwargs)
+
+    result_stub, tensors = pack_tensor_stub(result, [])
+    need_activation_save = (
+        any([t.requires_grad for t in tensors]) and is_training and is_grad_enabled
+    )
+    if need_activation_save:
+        save_activation(unique_key, tensors)
+
+    return result_stub, tensors, need_activation_save
 
 
 def launch_remote_backward(unique_key, *grad_outputs):
-    activation = _ACTIVATIONS.pop(unique_key)
+    activation = pop_activation(unique_key)
 
     # TODO; some output contains tuple of tuple..
     # better way to deal with this?
     new_act = []
     new_grad = []
     for act, grad in zip(activation, grad_outputs):
         if act is not None and grad is not None and act.requires_grad:
             new_act.append(act)
             new_grad.append(grad)
 
-    torch.autograd.backward(tuple(new_act), tuple(new_grad))
-    remove_forward_marker(unique_key)
+    if len(new_act) > 0 and len(new_grad) > 0:
+        torch.autograd.backward(tuple(new_act), tuple(new_grad))
+        notify_backward_job_done(unique_key)
 
 
-# TODO; why
+# why
 # forward(ctx, req, *args, **kwargs)
 # ...
 # return args, kwargs
 # does not work???
-# ->
+#
+# because that is the design of Pytorch
+# see: github.com/pytorch/pytorch/issues/16940
+#
 # based on https://github.com/facebookresearch/fairscale/blob/main/fairscale/nn/pipe/rpc.py#L53
 class _PipeBackwardRedirection(torch.autograd.Function):
     @staticmethod
     @custom_fwd
     def forward(ctx, to, unique_key, *args):
         ctx.to = to
         ctx.unique_key = unique_key
-        ctx.num_nones = 2 + len(args)  # counting req
+        ctx.num_nones = 2 + len(args)
 
         # mark
-        # TODO; do this before remote_forward
-        rpc.rpc_sync(to=to, func=add_forward_marker, args=(unique_key,))
+        # TODO; can we do this before remote_forward
+        # without rpc call?
+        rpc.rpc_sync(to=to, func=register_job_requires_backward, args=(unique_key,))
 
         return args
 
     @staticmethod
     @custom_bwd
     @rpc.functions.async_execution
     def backward(ctx, *grad_outputs):
         to = ctx.to
         unique_key = ctx.unique_key
 
         # print(f'backward: {to=}, {unique_key=}')
 
         rpc.rpc_async(
             to=to,
             func=launch_remote_backward,
(The rest of this file's diff and the remaining 6 changed files are not shown.)
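For context on how the two files above are meant to interact, here is a rough caller-side sketch over torch.distributed.rpc. It is an illustration, not the PR's actual caller code (which sits in the files not shown): the call_remote_forward helper, the flag values, the stub handling, and the body given for apply_backward_redirection are assumptions; the shown diff only confirms that remote_module_forward calls apply_backward_redirection and that _PipeBackwardRedirection.backward forwards gradients to launch_remote_backward via rpc_async.

import torch
from torch.distributed import rpc

from oslo.torch.nn.parallel.pipeline_parallel._functional import (
    remote_module_forward,
    _PipeBackwardRedirection,
)
from oslo.torch.nn.parallel.pipeline_parallel._messages import pack_tensor_stub


# Assumption: apply_backward_redirection (called inside remote_module_forward but
# defined in a part of the diff not shown) is presumably a thin wrapper like this.
def apply_backward_redirection(to, unique_key, *tensors):
    return _PipeBackwardRedirection.apply(to, unique_key, *tensors)


def call_remote_forward(owner, location, unique_key, *args, **kwargs):
    # Hypothetical caller-side helper: split (args, kwargs) into a tensor-free
    # stub plus a flat tensor list, then run the registered forward on the rank
    # that owns `location`. Assumes pack_tensor_stub mirrors the input structure,
    # as its use in remote_module_forward suggests.
    (args_stub, kwargs_stub), tensors = pack_tensor_stub([args, kwargs], [])

    result_stub, result_tensors, saved = rpc.rpc_sync(
        to=owner,
        func=remote_module_forward,
        args=(
            rpc.get_worker_info().name,  # caller; backward gets redirected here
            location,
            unique_key,
            args_stub,
            kwargs_stub,
            True,                        # requires_redirection
            True,                        # is_training
            torch.is_grad_enabled(),     # is_grad_enabled
            *tensors,
        ),
    )
    return result_stub, result_tensors, saved

In this scheme, the redirection Function is wrapped around the tensors the callee receives, with the caller as `to`; when a backward pass later reaches those tensors on the callee, the incoming gradients are shipped via rpc_async to launch_remote_backward on `to`, which pops whatever activations were saved under the same unique_key there and continues autograd locally, then signals completion with notify_backward_job_done.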