diff --git a/oslo/torch/nn/parallel/pipeline_parallel/_buffers.py b/oslo/torch/nn/parallel/pipeline_parallel/_buffers.py index 4f8f63c4..2ce852b4 100644 --- a/oslo/torch/nn/parallel/pipeline_parallel/_buffers.py +++ b/oslo/torch/nn/parallel/pipeline_parallel/_buffers.py @@ -1,3 +1,29 @@ +from oslo.torch.nn.parallel.pipeline_parallel._sync import ( + register_location_for_forward_counter, +) + + +# original forward dictionary +_ORIGINAL_FORWARDS = dict() + +# module device locations +_MODULE_DEVICE_LOCATIONS = dict() + + +def register_original_forward_function(location, func, device): + _ORIGINAL_FORWARDS[location] = func + _MODULE_DEVICE_LOCATIONS[location] = device + register_location_for_forward_counter(location) + + +def get_original_forward_function(location): + return _ORIGINAL_FORWARDS[location] + + +def get_module_device_location(location): + return _MODULE_DEVICE_LOCATIONS[location] + + # Activations _ACTIVATIONS = dict() @@ -7,4 +33,4 @@ def save_activation(key, activation): def pop_activation(key): - return _ACTIVATIONS.pop(key) + return _ACTIVATIONS.pop(key, []) # TODO; okay? diff --git a/oslo/torch/nn/parallel/pipeline_parallel/_functional.py b/oslo/torch/nn/parallel/pipeline_parallel/_functional.py index a8f5a977..48f42596 100644 --- a/oslo/torch/nn/parallel/pipeline_parallel/_functional.py +++ b/oslo/torch/nn/parallel/pipeline_parallel/_functional.py @@ -2,48 +2,59 @@ from torch.cuda.amp import custom_fwd, custom_bwd from torch.distributed import rpc -from oslo.torch.nn.parallel.pipeline_parallel._buffers import _ACTIVATIONS - -_FORWARD_MARKER = set() - -_LOCAL_BACKWARD_DONE = False - -_NUM_BACKWARD_DONE = 0 - - -def add_forward_marker(mark): - _FORWARD_MARKER.add(mark) - - -def remove_forward_marker(mark): - _FORWARD_MARKER.remove(mark) - - -def len_forward_marker(): - return len(_FORWARD_MARKER) - - -def increase_num_backward_done(): - global _NUM_BACKWARD_DONE - _NUM_BACKWARD_DONE += 1 +from oslo.torch.nn.parallel.pipeline_parallel._buffers import ( + get_original_forward_function, + save_activation, + pop_activation, +) +from oslo.torch.nn.parallel.pipeline_parallel._sync import ( + register_job_requires_backward, + notify_backward_job_done, +) +from oslo.torch.nn.parallel.pipeline_parallel._messages import ( + pack_tensor_stub, + unpack_tensor_stub, +) + + +def remote_module_forward( + caller, + location, + unique_key, + args_stub, + kwargs_stub, + requires_redirection, + is_training, + is_grad_enabled, + *tensors +): + if requires_redirection and is_training and is_grad_enabled: + # prepare backward redirection to caller + tensors = apply_backward_redirection( + caller, + unique_key, + *tensors, + ) + (args, kwargs), _ = unpack_tensor_stub([args_stub, kwargs_stub], tensors) -def get_num_backward_done(): - global _NUM_BACKWARD_DONE - return _NUM_BACKWARD_DONE + forward_fn = get_original_forward_function(location) + with torch.set_grad_enabled(is_grad_enabled): + result = forward_fn(*args, **kwargs) + result_stub, tensors = pack_tensor_stub(result, []) + need_activation_save = ( + any([t.requires_grad for t in tensors]) and is_training and is_grad_enabled + ) + if need_activation_save: + save_activation(unique_key, tensors) -def reset_num_backward_done(): - global _NUM_BACKWARD_DONE, _LOCAL_BACKWARD_DONE - _NUM_BACKWARD_DONE = 0 - _LOCAL_BACKWARD_DONE = False + return result_stub, tensors, need_activation_save def launch_remote_backward(unique_key, *grad_outputs): - activation = _ACTIVATIONS.pop(unique_key) + activation = pop_activation(unique_key) - # TODO; some output 
contains tuple of tuple.. - # better way to deal with this? new_act = [] new_grad = [] for act, grad in zip(activation, grad_outputs): @@ -51,18 +62,20 @@ def launch_remote_backward(unique_key, *grad_outputs): new_act.append(act) new_grad.append(grad) - torch.autograd.backward(tuple(new_act), tuple(new_grad)) - remove_forward_marker(unique_key) + if len(new_act) > 0 and len(new_grad) > 0: + torch.autograd.backward(tuple(new_act), tuple(new_grad)) + notify_backward_job_done(unique_key) -# TODO; why +# why # forward(ctx, req, *args, **kwargs) # ... # return args, kwargs # does not work??? -# -> +# # because that is the design of Pytorch # see: github.com/pytorch/pytorch/issues/16940 +# # based on https://github.com/facebookresearch/fairscale/blob/main/fairscale/nn/pipe/rpc.py#L53 class _PipeBackwardRedirection(torch.autograd.Function): @staticmethod @@ -70,23 +83,21 @@ class _PipeBackwardRedirection(torch.autograd.Function): def forward(ctx, to, unique_key, *args): ctx.to = to ctx.unique_key = unique_key - ctx.num_nones = 2 + len(args) # counting req + ctx.num_nones = 2 + len(args) # mark - # TODO; do this before remote_forward - rpc.rpc_sync(to=to, func=add_forward_marker, args=(unique_key,)) + # TODO; can we do this before remote_forward + # without rpc call? + rpc.rpc_sync(to=to, func=register_job_requires_backward, args=(unique_key,)) return args @staticmethod @custom_bwd - @rpc.functions.async_execution def backward(ctx, *grad_outputs): to = ctx.to unique_key = ctx.unique_key - # print(f'backward: {to=}, {unique_key=}') - rpc.rpc_async( to=to, func=launch_remote_backward, diff --git a/oslo/torch/nn/parallel/pipeline_parallel/_messages.py b/oslo/torch/nn/parallel/pipeline_parallel/_messages.py index f914f8db..e721f099 100644 --- a/oslo/torch/nn/parallel/pipeline_parallel/_messages.py +++ b/oslo/torch/nn/parallel/pipeline_parallel/_messages.py @@ -1,128 +1,135 @@ from dataclasses import dataclass -from typing import Optional, Any -from typing import Tuple, List, Union import torch -MESSAGE_GENERATION = 0 -REQUEST_GENERATION = 0 - - -@dataclass(init=False) -class Message: - comm_type: str - # 1. request or response - request_from: Optional[str] - # 2. request module id - exec_type: str - # 3. forward or backward - inputs: Optional[Any] - # 4. input data for module execution - outputs: Optional[Any] - # 5. output data from module execution - src_rank: int - # 6. source pp rank - dst_rank: int - # 7. destination pp rank - location: int - # 8. The location of the module within the module graph - in_autocast_context: bool - # 9. Whether the requester is currently in a autocast context - in_grad_related_context: bool - # 10. Whether the requester is currently in a no grad/enable grad context - use_activation_checkpointing: bool - # 11. 
Whether activation checkpointing is enabled for the current module - - def __init__(self): - global MESSAGE_GENERATION - MESSAGE_GENERATION += 1 - self.tag = MESSAGE_GENERATION +from oslo.torch.nn.parallel.pipeline_parallel._utils import ( + _is_namedtuple, + _is_private, + _is_primitive, +) @dataclass class TensorStub(object): - id: str - dtype: torch.dtype - shape: Union[List, Tuple] - requires_grad: bool - - -@dataclass(init=False) -class RemoteWorkRequest: - src: torch.device - dst: torch.device - location: str - tag: int - caller: str - keys: tuple - - def __init__(self): - global REQUEST_GENERATION - REQUEST_GENERATION += 1 - self.tag = REQUEST_GENERATION - - -def generate_request(src, dst, location, caller, args, kwargs): - req = RemoteWorkRequest() - req.src = src - req.dst = dst - req.location = location - req.caller = caller - - # merge kwargs into args - keys, new_args = assemble_args(args, kwargs) - req.keys = keys - - return req, new_args - - -def assemble_args(args, kwargs): - new_args = [] - keys = [] - for v in args: - if torch.is_tensor(v): - v = v.contiguous() - new_args.append(v) - keys.append(None) - - for k, v in kwargs.items(): - if k is None: - raise ValueError("None cannot be used the key of kwargs.") - if torch.is_tensor(v): - v = v.contiguous() - new_args.append(v) - keys.append(k) - - return tuple(keys), tuple(new_args) - - -def disassemble_new_args(new_args, keys): - args = list() - kwargs = dict() - - for k, v in zip(keys, new_args): - if k is None: - args.append(v) - else: - kwargs[k] = v - - return tuple(args), kwargs - - -def disassemble_result(result): - if isinstance(result, torch.Tensor): - args = (result,) - kwargs = dict() - wrapped = True - elif isinstance(result, dict): - args = tuple([]) - kwargs = result - wrapped = False - elif isinstance(result, (list, tuple)): - args = tuple(result) - kwargs = dict() - wrapped = False - else: - raise NotImplementedError - - return args, kwargs, wrapped + id: int + + +def pack_tensor_stub(obj, args_list): + """ + Recursively replace Tensor member variables to TensorStub. 
+ Inspiration: https://github.com/pytorch/pytorch/blob/master/torch/distributed/utils.py#L48 + """ + if torch.is_tensor(obj): + id_ = len(args_list) + tensor_sub = TensorStub(id_) + args_list.append(obj) + obj = tensor_sub + + return obj, args_list + + elif _is_namedtuple(obj): + obj_list = list(obj) + for i in range(len(obj_list)): + obj_list_i, args_list = pack_tensor_stub(obj_list[i], args_list) + obj_list_i[i] = obj_list_i + obj = obj.__class__._make(obj_list) # use namedtuple's method + + return obj, args_list + + elif isinstance(obj, tuple): + obj = list(obj) + for i in range(len(obj)): + obj_i, args_list = pack_tensor_stub(obj[i], args_list) + obj[i] = obj_i + + obj = tuple(obj) + return obj, args_list + + elif isinstance(obj, list): + for i in range(len(obj)): + obj_i, args_list = pack_tensor_stub(obj[i], args_list) + obj[i] = obj_i + + return obj, args_list + + elif isinstance(obj, dict): + for k in obj.keys(): + obj_k, args_list = pack_tensor_stub(obj[k], args_list) + obj[k] = obj_k + + return obj, args_list + + elif _is_primitive(obj): + return obj, args_list + + else: # other kinds of object + members = [ + attr + for attr in dir(obj) + if not callable(getattr(obj, attr)) and not _is_private(attr) + ] + for m in members: + obj_m = getattr(obj, m) + obj_m, args_list = pack_tensor_stub(obj_m, args_list) + setattr(obj, m, obj_m) + + return obj, args_list + + +def unpack_tensor_stub(obj, args_list): + """ + Recursively replace TensorStub to original Tensor. + Inspiration: https://github.com/pytorch/pytorch/blob/master/torch/distributed/utils.py#L48 + """ + if isinstance(obj, TensorStub): + id_ = obj.id + tensor = args_list[id_] + return tensor, args_list + + elif _is_namedtuple(obj): + obj_list = list(obj) + for i in range(len(obj_list)): + obj_list_i, args_list = unpack_tensor_stub(obj_list[i], args_list) + obj_list_i[i] = obj_list_i + obj = obj.__class__._make(obj_list) # use namedtuple's method + + return obj, args_list + + elif isinstance(obj, tuple): + obj = list(obj) + for i in range(len(obj)): + obj_i, args_list = unpack_tensor_stub(obj[i], args_list) + obj[i] = obj_i + + obj = tuple(obj) + return obj, args_list + + elif isinstance(obj, list): + for i in range(len(obj)): + obj_i, args_list = unpack_tensor_stub(obj[i], args_list) + obj[i] = obj_i + + return obj, args_list + + elif isinstance(obj, dict): + for k in obj.keys(): + obj_k, args_list = unpack_tensor_stub(obj[k], args_list) + obj[k] = obj_k + + return obj, args_list + + elif _is_primitive(obj): + return obj, args_list + + else: # other kinds of object + members = [ + attr + for attr in dir(obj) + if not callable(getattr(obj, attr)) and not _is_private(attr) + ] + for m in members: + obj_m = getattr(obj, m) + obj_m, args_list = unpack_tensor_stub(obj_m, args_list) + setattr(obj, m, obj_m) + + return obj, args_list diff --git a/oslo/torch/nn/parallel/pipeline_parallel/_server.py b/oslo/torch/nn/parallel/pipeline_parallel/_server.py deleted file mode 100644 index ff8c6bc2..00000000 --- a/oslo/torch/nn/parallel/pipeline_parallel/_server.py +++ /dev/null @@ -1,141 +0,0 @@ -import time -from queue import PriorityQueue - -import torch - -from oslo.torch.nn.parallel.pipeline_parallel._buffers import ( - save_activation, - pop_activation, -) -from oslo.torch.nn.parallel.pipeline_parallel._functional import ( - apply_backward_redirection, -) -from oslo.torch.nn.parallel.pipeline_parallel._messages import disassemble_new_args - -# original forward dictionary -_ORIGINAL_FORWARDS = dict() - -# module device locations 
-_MODULE_DEVICE_LOCATIONS = dict() - - -# Job queue -_JOB_QUEUE = PriorityQueue() - -# remote work result receiver -_RECEIVER = dict() - - -_RESULT_DICT = dict() - - -_DONE_CHECKER = 0 - - -_FORWARD_COUNTER = dict() - - -def get_result(ind): - while ind not in _RESULT_DICT: - time.sleep(0.0) - return _RESULT_DICT[ind] - - -def reset_result(): - _RESULT_DICT.clear() - - -def get_forward_counter(loc): - return _FORWARD_COUNTER[loc] - - -def increment_forward_counter(loc): - _FORWARD_COUNTER[loc] += 1 - - -def reset_forward_counter(): - for k in _FORWARD_COUNTER: - _FORWARD_COUNTER[k] = 0 - - -def increment_done(): - global _DONE_CHECKER - _DONE_CHECKER += 1 - - -def get_done(): - global _DONE_CHECKER - return _DONE_CHECKER - - -def reset_done(): - global _DONE_CHECKER - _DONE_CHECKER = 0 - - -def reset_backward_notify(): - global _NOTIFY_BACKWARD_DONE - _NOTIFY_BACKWARD_DONE = False - - -def backward_done_notify(): - global _NOTIFY_BACKWARD_DONE - _NOTIFY_BACKWARD_DONE = True - - -def wait_backward_done(): - global _NOTIFY_BACKWARD_DONE - while not _NOTIFY_BACKWARD_DONE: - time.sleep(0.0) - - -def remote_module_forward(caller, location, unique_key, arg_keys, *args): - # prepare backward redirection to caller - args = apply_backward_redirection( - caller, - unique_key, - *args, - ) - - args, kwargs = disassemble_new_args(args, arg_keys) - forward_fn = _ORIGINAL_FORWARDS[location] - result = forward_fn(*args, **kwargs) - save_activation(unique_key, result) - return result - - -def wait_remote_work_result(request_message): - tag = request_message.tag - assert tag in _RECEIVER, f"{tag}" - result = _RECEIVER[tag].get() - torch.cuda.current_stream().synchronize() - - # delete a queue for communication - _RECEIVER.pop(tag) - return result - - -def response_with_result(req, tag, result, result_wrapped): - result = (req, result, result_wrapped) - _RECEIVER[tag].put(result) - torch.cuda.current_stream().synchronize() - - -def run_remote_backward(req, *grad_outputs): - # need to ensure that grad_outputs is fully received - # TODO; no other way? - torch.cuda.synchronize() - - tag = req.tag - activation, req = pop_activation(tag) - - # TODO; some output contains tuple of tuple.. - # better way to deal with this? 
- new_act = [] - new_grad = [] - for act, grad in zip(activation, grad_outputs): - if act is not None and grad is not None and act.requires_grad: - new_act.append(act) - new_grad.append(grad) - - torch.autograd.backward(tuple(new_act), tuple(new_grad)) diff --git a/oslo/torch/nn/parallel/pipeline_parallel/_sync.py b/oslo/torch/nn/parallel/pipeline_parallel/_sync.py new file mode 100644 index 00000000..a9d9cd9a --- /dev/null +++ b/oslo/torch/nn/parallel/pipeline_parallel/_sync.py @@ -0,0 +1,99 @@ +import time + +from torch.distributed import rpc + +from oslo.torch.distributed.parallel_mode import ParallelMode + + +# for watching whether every backward work is done or not +_JOBS_REQUIRE_BACKWARD = set() + + +def register_job_requires_backward(job_name): + _JOBS_REQUIRE_BACKWARD.add(job_name) + + +def notify_backward_job_done(job_name): + _JOBS_REQUIRE_BACKWARD.remove(job_name) + + +def get_num_jobs_require_backward_remaining(): + return len(_JOBS_REQUIRE_BACKWARD) + + +# for unique tag generation +_NUM_FORWARD_USED_COUNTER = dict() + + +def register_location_for_forward_counter(location): + _NUM_FORWARD_USED_COUNTER[location] = 0 + + +def make_unique_key(location, rank): + cnt = _NUM_FORWARD_USED_COUNTER[location] + unique_key = (location, cnt, rank) + _NUM_FORWARD_USED_COUNTER[location] += 1 + return unique_key + + +def reset_forward_used_counter(): + for k in _NUM_FORWARD_USED_COUNTER: + _NUM_FORWARD_USED_COUNTER[k] = 0 + + +# dictionary for result broadcast +_RESULT_DICT = dict() + +_RESULT_RECEIVED_MARKER = dict() + + +def set_result(ind, result): + _RESULT_DICT[ind] = result + _RESULT_RECEIVED_MARKER[ind] = True + + +def get_result(ind): + while ind not in _RESULT_RECEIVED_MARKER: + time.sleep(0.0) + return _RESULT_DICT[ind] + + +def reset_result(): + _RESULT_DICT.clear() + _RESULT_RECEIVED_MARKER.clear() + + +# +_CHECKER_BATCH_JOB_FINISHED = 0 + + +def notify_batch_job_finished(): + global _CHECKER_BATCH_JOB_FINISHED + _CHECKER_BATCH_JOB_FINISHED += 1 + + +def wait_other_ranks(rank, context): + global _CHECKER_BATCH_JOB_FINISHED + + # TODO; check the reason why we need this code block + # for checking batch job done. + # gradient computation goes wrong without this code + for other in context.get_ranks_in_group(ParallelMode.PIPELINE): + if other == rank: + notify_batch_job_finished() + else: + rpc_dst = context.get_pipeline_rpc_worker_name(other) + rpc.rpc_sync( + to=rpc_dst, + func=notify_batch_job_finished, + ) + + while _CHECKER_BATCH_JOB_FINISHED < context.get_world_size(ParallelMode.PIPELINE): + time.sleep(0.0) + + # every ranks done; reset + _CHECKER_BATCH_JOB_FINISHED = 0 + + # wait for all backward pass execution + while get_num_jobs_require_backward_remaining() != 0: + time.sleep(0.0) diff --git a/oslo/torch/nn/parallel/pipeline_parallel/_utils.py b/oslo/torch/nn/parallel/pipeline_parallel/_utils.py index b8bcdb26..61cc17ba 100644 --- a/oslo/torch/nn/parallel/pipeline_parallel/_utils.py +++ b/oslo/torch/nn/parallel/pipeline_parallel/_utils.py @@ -27,10 +27,17 @@ def post_order_traverse(node): yield node -def is_iterable(data): - try: - iter(data) - except TypeError: - return False - else: - return True +# from https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/scatter_gather.py#L12 +def _is_namedtuple(obj): + # Check if type was created from collections.namedtuple or a typing.NamedTuple. 
+ return ( + isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields") + ) + + +def _is_primitive(obj): + return not hasattr(obj, "__dict__") + + +def _is_private(attr): + return attr.startswith("__") diff --git a/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py b/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py index 3f67007d..389e77fe 100644 --- a/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py +++ b/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py @@ -1,5 +1,4 @@ import concurrent.futures -import time from threading import Lock import torch @@ -8,29 +7,29 @@ from oslo.torch.distributed.parallel_context import ParallelContext from oslo.torch.distributed.parallel_mode import ParallelMode -from oslo.torch.nn.parallel.pipeline_parallel._buffers import save_activation +from oslo.torch.nn.parallel.utils import get_parallel_context +from oslo.torch.nn.parallel.pipeline_parallel._buffers import ( + register_original_forward_function, + get_original_forward_function, + get_module_device_location, + save_activation, +) from oslo.torch.nn.parallel.pipeline_parallel._functional import ( + remote_module_forward, apply_backward_redirection, - len_forward_marker, ) -from oslo.torch.nn.parallel.pipeline_parallel._messages import assemble_args -from oslo.torch.nn.parallel.pipeline_parallel._model_partitioner import ModelPartitioner -from oslo.torch.nn.parallel.pipeline_parallel._server import ( - _ORIGINAL_FORWARDS, - _MODULE_DEVICE_LOCATIONS, - remote_module_forward, - _RESULT_DICT, +from oslo.torch.nn.parallel.pipeline_parallel._sync import ( + wait_other_ranks, + make_unique_key, + reset_forward_used_counter, + set_result, get_result, - reset_result, - increment_done, - get_done, - reset_done, - _FORWARD_COUNTER, - get_forward_counter, - increment_forward_counter, - reset_forward_counter, ) -from oslo.torch.nn.parallel.utils import get_parallel_context +from oslo.torch.nn.parallel.pipeline_parallel._messages import ( + pack_tensor_stub, + unpack_tensor_stub, +) +from oslo.torch.nn.parallel.pipeline_parallel._model_partitioner import ModelPartitioner def PipelineParallel( @@ -40,7 +39,21 @@ def PipelineParallel( num_micro_batches: int = 1, ): # TODO, @HG - pass + return _PipelineParallel( + module=module, + parallel_context=parallel_context, + memory_computation_balance=memory_computation_balance, + num_micro_batches=num_micro_batches, + ) + + +# function to launch self.module. needs this +# function because torch.set_grad_enabled() is +# thread local. 
+def launch(fn, is_grad_enabled, *args, **kwargs): + with torch.set_grad_enabled(is_grad_enabled): + result = fn(*args, **kwargs) + return result class _PipelineParallel(nn.Module): @@ -98,11 +111,12 @@ def __init__( def forward(self, *args, **kwargs): # set forward counter to zero - reset_forward_counter() + reset_forward_used_counter() # to ensure optimizer's step is done for all processes - # TODO; barrier for only PP - torch.distributed.barrier() + torch.distributed.barrier( + self.parallel_context.get_group(ParallelMode.PIPELINE) + ) if self.rank == 0: # TODO; @@ -129,22 +143,26 @@ def forward(self, *args, **kwargs): # futures.append(future) # ind += 1 + is_grad_enabled = torch.is_grad_enabled() for ind, (args_, kwargs_) in enumerate(zip(new_args, new_kwargs)): - future = self.producer.submit(self.module, *args_, **kwargs_) + future = self.producer.submit( + launch, self.module, is_grad_enabled, *args_, **kwargs_ + ) futures.append(future) for i, done in enumerate(concurrent.futures.as_completed(futures)): result = done.result() - _RESULT_DICT[i] = result - - # print(f'{i=}, {result.loss=}, {dist.get_rank()=}') + set_result(i, result) yield result else: + # TODO; the code block below does not make + # same number with rank 0. However, since + # this result is a dummy, it does not cause + # an error. # forward pass end, wait results from master for i in range(self.num_micro_batches): - # result = FINAL_RESULT_QUEUE.get() rpc_dst = self.parallel_context.get_pipeline_rpc_worker_name(0) result = rpc.rpc_sync( to=rpc_dst, @@ -152,29 +170,20 @@ def forward(self, *args, **kwargs): args=(i,), ) - yield result # has no gradient - - # barrier ? - # TODO; check the reason why we need this code block - for other in self.parallel_context.get_ranks_in_group(ParallelMode.PIPELINE): - if other == self.rank: - increment_done() - else: - rpc_dst = self.parallel_context.get_pipeline_rpc_worker_name(other) - rpc.rpc_sync( - to=rpc_dst, - func=increment_done, - ) - - while get_done() < self.parallel_context.get_world_size(ParallelMode.PIPELINE): - time.sleep(0.0) + # remove gradient of non-master result. + # without this, the users need to consider rank + # when calling loss.backward() + result, tensors = pack_tensor_stub(result, []) + for i_tensor in range(len(tensors)): + tensors[i_tensor].grad = None + result, _ = unpack_tensor_stub(result, tensors) - reset_done() - reset_result() + yield result - while len_forward_marker() != 0: - time.sleep(0.0) + # barrier; wait for all rank + wait_other_ranks(self.rank, self.parallel_context) + # TODO; seems like this is not necessary? 
torch.cuda.empty_cache() def _recursive_wrap(self, module, prefix): @@ -192,29 +201,34 @@ def _wrap_forward(self, module): loc = module.location device = module.oslo_parallel[ParallelMode.PIPELINE] - _ORIGINAL_FORWARDS[loc] = orig_forward - _MODULE_DEVICE_LOCATIONS[loc] = device - _FORWARD_COUNTER[loc] = 0 + register_original_forward_function(loc, orig_forward, device) def new_forward(*args, **kwargs): location = module.location - module_device = _MODULE_DEVICE_LOCATIONS[location] + module_device = get_module_device_location(location) module_device = torch.device("cuda", module_device) current_device = self.parallel_context.get_local_rank(ParallelMode.PIPELINE) current_device = torch.device("cuda", current_device) is_same = module_device == current_device if is_same: - forward_fn = _ORIGINAL_FORWARDS[location] + forward_fn = get_original_forward_function(location) result = forward_fn(*args, **kwargs) else: - arg_keys, new_args = assemble_args(args, kwargs) + (args_stub, kwargs_stub), tensors = pack_tensor_stub([args, kwargs], []) + tensors = tuple(tensors) + # does not save activation if the module is in eval mode + is_grad_enabled = torch.is_grad_enabled() + is_training = self.training + need_activation_save = any([t.requires_grad for t in tensors]) with self._lock: - cnt = get_forward_counter(location) - unique_key = (location, cnt) - increment_forward_counter(location) + unique_key = make_unique_key(location, self.rank) + + if need_activation_save and is_training and is_grad_enabled: + # prepare backward + save_activation(unique_key, tensors) caller = self.parallel_context.get_pipeline_rpc_worker_name( current_device.index @@ -223,23 +237,33 @@ def new_forward(*args, **kwargs): module_device.index ) - # prepare backward - save_activation(unique_key, new_args) - # request forward fut = rpc.rpc_async( to=callee, func=remote_module_forward, - args=(caller, location, unique_key, arg_keys) + new_args, + args=( + caller, + location, + unique_key, + args_stub, + kwargs_stub, + need_activation_save, + is_training, + is_grad_enabled, + ) + + tensors, ) - result = fut.wait() + # receive result as stub + result_stub, tensors, requires_redirection = fut.wait() - # TODO; does result always be an args? 
- result = apply_backward_redirection( - callee, - unique_key, - *result, - ) + if requires_redirection and is_training and is_grad_enabled: + tensors = apply_backward_redirection( + callee, + unique_key, + *tensors, + ) + + result, _ = unpack_tensor_stub(result_stub, tensors) return result diff --git a/tests/torch/nn/parallel/pipeline_parallel/test_pp.py b/tests/torch/nn/parallel/pipeline_parallel/test_pp.py index 7c0072b1..bbf46d58 100644 --- a/tests/torch/nn/parallel/pipeline_parallel/test_pp.py +++ b/tests/torch/nn/parallel/pipeline_parallel/test_pp.py @@ -1,19 +1,166 @@ from copy import deepcopy -import matplotlib -import matplotlib.pyplot as plt import torch -import torch.distributed as dist import torch.nn as nn -from datasets import load_dataset +import torch.distributed as dist from torch.distributed import rpc from torch.optim import Adam from torch.utils.data import DataLoader -from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel, set_seed from oslo.torch.distributed import ParallelContext -from oslo.torch.nn.parallel import _PipelineParallel +from oslo.torch.nn.parallel import PipelineParallel from oslo.torch.nn.parallel.utils import allocate_params +from oslo.torch.nn.parallel.pipeline_parallel._buffers import _MODULE_DEVICE_LOCATIONS + +from datasets import load_dataset +from transformers import ( + AutoTokenizer, + GPT2Config, + GPT2LMHeadModel, + T5Config, + T5ForConditionalGeneration, + BartConfig, + BartForConditionalGeneration, + set_seed, +) + +import matplotlib +import matplotlib.pyplot as plt + + +# for debugging +from typing import Optional, Tuple, Union +from transformers.modeling_outputs import ( + BaseModelOutput, + Seq2SeqLMOutput, +) +from torch.nn import CrossEntropyLoss + + +class T5Debug(T5ForConditionalGeneration): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: + + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + # Convert encoder inputs in embeddings if needed + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) 
> 2 else None, + ) + + hidden_states = encoder_outputs[0] + + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + + if ( + labels is not None + and decoder_input_ids is None + and decoder_inputs_embeds is None + ): + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to( + self.decoder.first_device + ) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.encoder.first_device) + self.lm_head = self.lm_head.to(self.encoder.first_device) + sequence_output = sequence_output.to(self.lm_head.weight.device) + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim**-0.5) + + lm_logits = self.lm_head(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 + + if not return_dict: + output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + matplotlib.use("Agg") torch.autograd.set_detect_anomaly(True) @@ -21,18 +168,28 @@ parallel_context = ParallelContext.from_torch( data_parallel_size=1, - pipeline_parallel_size=2, + pipeline_parallel_size=4, tensor_parallel_size=1, ) current_device = torch.cuda.current_device() - -model_name = "gpt2" num_micro_batches = 8 -config = GPT2Config.from_pretrained(model_name) +# model_name = "gpt2" +# config = GPT2Config.from_pretrained(model_name) +# model = GPT2LMHeadModel(config) + +model_name = "t5-small" +config = T5Config.from_pretrained(model_name) +config.dropout_rate = 0.0 +model = T5ForConditionalGeneration(config) +# model = T5Debug(config) + +# model_name = "facebook/bart-base" +# config 
= BartConfig.from_pretrained(model_name) +# config.dropout_rate = 0. +# model = BartForConditionalGeneration(config) -model = GPT2LMHeadModel(config) for n, m in model.named_modules(): if isinstance(m, nn.Dropout): @@ -41,7 +198,7 @@ model_no_pp = deepcopy(model) model_no_pp.cuda() -wrapper_pp = _PipelineParallel( +wrapper_pp = PipelineParallel( model, parallel_context=parallel_context, memory_computation_balance=1.0, @@ -50,11 +207,27 @@ wrapper_pp.train() -optimizer_pp = Adam(wrapper_pp.parameters(), lr=3e-5) -optimizer_no_pp = Adam(model_no_pp.parameters(), lr=3e-5) +optimizer_pp = Adam(wrapper_pp.parameters(), lr=3e-2) +optimizer_no_pp = Adam(model_no_pp.parameters(), lr=3e-2) allocate_params(wrapper_pp, parallel_context) +# +# def print_location_backward_hook(m, i, o): +# print(f'{torch.distributed.get_rank()=}, {m.location=}') +# return i +# +# +# for name, m in wrapper_pp.named_modules(): +# m: nn.Module +# if hasattr(m, 'location'): +# m.register_full_backward_hook(print_location_backward_hook) +# +# +if torch.distributed.get_rank() == 1: + for k, v in _MODULE_DEVICE_LOCATIONS.items(): + print(f"{k}: {v}") + def run(): batch_size = 8 * num_micro_batches @@ -88,18 +261,16 @@ def run(): ): loss_pp = out_pp.loss loss_pp = loss_pp / num_micro_batches + loss_pp.backward() - if dist.get_rank() == 0: - loss_pp.backward() - - print(f"{ind}") cum_loss_pp += loss_pp.detach().item() out_no_pp = model_no_pp(**inputs, labels=inputs["input_ids"]) loss_no_pp = out_no_pp.loss loss_no_pp.backward() - print(f"{dist.get_rank()}, {cum_loss_pp}, {loss_no_pp}") + if dist.get_rank() == 0: + print(f"{dist.get_rank()=}, {cum_loss_pp=}, {loss_no_pp=}") optimizer_pp.step() optimizer_no_pp.step()
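
Usage note (not part of the patch): a minimal sketch of the stub round-trip introduced in _messages.py, assuming the patched oslo package is importable. remote_module_forward and the wrapped new_forward rely on this pair to ship arbitrarily nested inputs and outputs over RPC: tensors are swapped for TensorStub placeholders plus a flat tensor list on the sending side, and restored from that list on the receiving side.

import torch

from oslo.torch.nn.parallel.pipeline_parallel._messages import (
    TensorStub,
    pack_tensor_stub,
    unpack_tensor_stub,
)

ids = torch.randint(0, 100, (2, 8))
hidden = torch.randn(2, 8, requires_grad=True)
payload = {"input_ids": ids, "extra": ["not-a-tensor", hidden]}

# pack: nested tensors -> TensorStub(id) placeholders + a flat tensor list
# (dicts and lists are updated in place; tuples are rebuilt)
stub, tensors = pack_tensor_stub(payload, [])
assert isinstance(stub["input_ids"], TensorStub) and len(tensors) == 2

# the stub travels as picklable metadata and the tensor list as the *tensors
# varargs of remote_module_forward; the receiver rebuilds the original structure
restored, _ = unpack_tensor_stub(stub, tensors)
assert torch.equal(restored["input_ids"], ids) and restored["extra"][1] is hidden

The same flat tensor list is what the any(t.requires_grad for t in tensors) checks in new_forward and remote_module_forward inspect when deciding whether to save activations and install backward redirection.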
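
The comment above launch() in pipeline_parallel.py points out that torch.set_grad_enabled() is thread local. Below is a small self-contained sketch (independent of oslo, with a helper mirroring the patched launch()) of why the caller's grad mode has to be forwarded into the producer thread explicitly:

import torch
from concurrent.futures import ThreadPoolExecutor


def report_grad_mode():
    return torch.is_grad_enabled()


def launch(fn, is_grad_enabled, *args, **kwargs):
    # re-apply the caller's grad mode inside the worker thread, as
    # _PipelineParallel.forward now does via self.producer.submit(launch, ...)
    with torch.set_grad_enabled(is_grad_enabled):
        return fn(*args, **kwargs)


with ThreadPoolExecutor(max_workers=1) as pool:
    with torch.no_grad():
        # grad mode is thread local: the pool thread ignores the caller's
        # no_grad() context and reports the per-thread default (enabled)
        print(pool.submit(report_grad_mode).result())  # True
        # forwarding the caller's mode restores the expected behavior
        print(pool.submit(launch, report_grad_mode, torch.is_grad_enabled()).result())  # False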
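
A short sketch of the backward bookkeeping in _sync.py that replaces the old _FORWARD_MARKER set, again assuming the patched oslo package; the key layout matches make_unique_key and the location string here is hypothetical. _PipeBackwardRedirection.forward registers the key on the remote rank via rpc_sync, launch_remote_backward clears it, and wait_other_ranks spins until the remaining count reaches zero:

from oslo.torch.nn.parallel.pipeline_parallel._sync import (
    register_job_requires_backward,
    notify_backward_job_done,
    get_num_jobs_require_backward_remaining,
)

# (location, per-location forward counter, calling rank), as built by make_unique_key
key = ("decoder.block.0", 0, 0)  # illustrative location string

register_job_requires_backward(key)  # done over rpc in _PipeBackwardRedirection.forward
assert get_num_jobs_require_backward_remaining() == 1

# ... the matching remote backward runs, ending in launch_remote_backward ...
notify_backward_job_done(key)
assert get_num_jobs_require_backward_remaining() == 0  # wait_other_ranks can now return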