From 124c527fd55121869e54374aa76c61b4879238d3 Mon Sep 17 00:00:00 2001 From: ohwi Date: Mon, 26 Sep 2022 13:59:44 +0900 Subject: [PATCH 1/9] support HF T5 --- .../parallel/pipeline_parallel/_functional.py | 20 ++- .../parallel/pipeline_parallel/_messages.py | 13 +- .../nn/parallel/pipeline_parallel/_server.py | 20 ++- .../pipeline_parallel/pipeline_parallel.py | 75 ++++++-- .../nn/parallel/pipeline_parallel/test_pp.py | 170 +++++++++++++++++- 5 files changed, 269 insertions(+), 29 deletions(-) diff --git a/oslo/torch/nn/parallel/pipeline_parallel/_functional.py b/oslo/torch/nn/parallel/pipeline_parallel/_functional.py index 9a517519..b32d81c3 100644 --- a/oslo/torch/nn/parallel/pipeline_parallel/_functional.py +++ b/oslo/torch/nn/parallel/pipeline_parallel/_functional.py @@ -14,13 +14,16 @@ def add_forward_marker(mark): _FORWARD_MARKER.add(mark) + # print(f'ADD MARKER: {torch.distributed.get_rank()=}, {_FORWARD_MARKER=}') def remove_forward_marker(mark): _FORWARD_MARKER.remove(mark) + # print(f'REMOVE MARKER: {torch.distributed.get_rank()=}, {_FORWARD_MARKER=}') def len_forward_marker(): + # print(f'{torch.distributed.get_rank()=}, {_FORWARD_MARKER=}') return len(_FORWARD_MARKER) @@ -41,7 +44,17 @@ def reset_num_backward_done(): def launch_remote_backward(unique_key, *grad_outputs): - activation = _ACTIVATIONS.pop(unique_key) + activation = _ACTIVATIONS.pop(unique_key, []) # TODO; is this safe? + + # TODO; HF output... + if isinstance(activation, dict): + activation = tuple(activation.values()) + + # TODO; some activations are not tuple... WHY? + if not isinstance(activation, tuple): + activation = (activation, ) + + # print(f'{unique_key=}, {type(activation)=}, {len(activation)=}, {len(grad_outputs)=}') # TODO; some output contains tuple of tuple.. # better way to deal with this? @@ -74,7 +87,8 @@ def forward(ctx, to, unique_key, *args): ctx.num_nones = 2 + len(args) # counting req # mark - # TODO; do this before remote_forward + # TODO; can we do this before remote_forward + # without rpc call? 
rpc.rpc_sync(to=to, func=add_forward_marker, args=(unique_key,)) return args @@ -86,8 +100,6 @@ def backward(ctx, *grad_outputs): to = ctx.to unique_key = ctx.unique_key - # print(f'backward: {to=}, {unique_key=}') - rpc.rpc_async( to=to, func=launch_remote_backward, diff --git a/oslo/torch/nn/parallel/pipeline_parallel/_messages.py b/oslo/torch/nn/parallel/pipeline_parallel/_messages.py index f914f8db..ab035ee3 100644 --- a/oslo/torch/nn/parallel/pipeline_parallel/_messages.py +++ b/oslo/torch/nn/parallel/pipeline_parallel/_messages.py @@ -79,21 +79,32 @@ def generate_request(src, dst, location, caller, args, kwargs): def assemble_args(args, kwargs): new_args = [] keys = [] + requires_grads = [] for v in args: + requires_grad = False + if torch.is_tensor(v): v = v.contiguous() + requires_grad = v.requires_grad + new_args.append(v) keys.append(None) + requires_grads.append(requires_grad) for k, v in kwargs.items(): if k is None: raise ValueError("None cannot be used the key of kwargs.") + requires_grad = False + if torch.is_tensor(v): v = v.contiguous() + requires_grad = v.requires_grad + new_args.append(v) keys.append(k) + requires_grads.append(requires_grad) - return tuple(keys), tuple(new_args) + return tuple(keys), tuple(new_args), tuple(requires_grads) def disassemble_new_args(new_args, keys): diff --git a/oslo/torch/nn/parallel/pipeline_parallel/_server.py b/oslo/torch/nn/parallel/pipeline_parallel/_server.py index 6ace7605..4ade5967 100644 --- a/oslo/torch/nn/parallel/pipeline_parallel/_server.py +++ b/oslo/torch/nn/parallel/pipeline_parallel/_server.py @@ -16,6 +16,8 @@ # original forward dictionary _ORIGINAL_FORWARDS = dict() +_WRAPPED_FORWARDS = dict() + # module device locations _MODULE_DEVICE_LOCATIONS = dict() @@ -90,17 +92,21 @@ def wait_backward_done(): time.sleep(0.0) -def remote_module_forward(caller, location, unique_key, arg_keys, *args): - # prepare backward redirection to caller - args = apply_backward_redirection( - caller, - unique_key, - *args, - ) +def remote_module_forward(caller, location, unique_key, arg_keys, requires_redirection, *args): + if requires_redirection: + # prepare backward redirection to caller + args = apply_backward_redirection( + caller, + unique_key, + *args, + ) args, kwargs = disassemble_new_args(args, arg_keys) + forward_fn = _ORIGINAL_FORWARDS[location] + # forward_fn = _WRAPPED_FORWARDS[location] result = forward_fn(*args, **kwargs) + # TODO; check whether activation need to be saved save_activation(unique_key, result) return result diff --git a/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py b/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py index 7c2f84ca..af346e2e 100644 --- a/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py +++ b/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py @@ -19,6 +19,7 @@ from oslo.torch.nn.parallel.pipeline_parallel._messages import assemble_args from oslo.torch.nn.parallel.pipeline_parallel._server import ( _ORIGINAL_FORWARDS, + _WRAPPED_FORWARDS, _MODULE_DEVICE_LOCATIONS, remote_module_forward, _RESULT_DICT, @@ -204,13 +205,28 @@ def new_forward(*args, **kwargs): result = forward_fn(*args, **kwargs) else: - arg_keys, new_args = assemble_args(args, kwargs) + arg_keys, new_args, requires_grads = assemble_args(args, kwargs) + print(f'{new_args=}') + + need_activation_save = any(requires_grads) with self._lock: cnt = get_forward_counter(location) - unique_key = (location, cnt) + unique_key = (location, cnt, self.rank) increment_forward_counter(location) + if 
need_activation_save: + # prepare backward + save_activation(unique_key, new_args) + + # with self._lock: + # cnt = get_forward_counter(location) + # unique_key = (location, cnt) + # increment_forward_counter(location) + # + # # prepare backward + # save_activation(unique_key, new_args) + caller = self.parallel_context.get_pipeline_rpc_worker_name( current_device.index ) @@ -218,24 +234,59 @@ def new_forward(*args, **kwargs): module_device.index ) - # prepare backward - save_activation(unique_key, new_args) - # request forward fut = rpc.rpc_async( to=callee, func=remote_module_forward, - args=(caller, location, unique_key, arg_keys) + new_args, + args=(caller, location, unique_key, arg_keys, need_activation_save) + new_args, ) result = fut.wait() - # TODO; does result always be an args? - result = apply_backward_redirection( - callee, - unique_key, - *result, - ) + # TODO; does result always be an args? what if dict? + # HF output is OrderedDict!! + # need to deal with recursive case... + is_dict = False + orig_result = None + if isinstance(result, dict): + is_dict = True + orig_result = result + + # print(f'{result.keys()=}') + # print(f'{len(result)=}') + + result = tuple(result.values()) + + # print(f'{len(result)=}') + + wrapped = False + if not isinstance(result, tuple): + wrapped = True + result = (result, ) + + # re-check + requires_redirection = False + for x in result: + if torch.is_tensor(x): + if x.requires_grad: + requires_redirection = True + break + + if requires_redirection: + result = apply_backward_redirection( + callee, + unique_key, + *result, + ) + + if wrapped: + result = result[0] + + if is_dict: + for i, k in enumerate(orig_result.keys()): + orig_result[k] = result[i] + result = orig_result return result module.forward = new_forward + _WRAPPED_FORWARDS[loc] = new_forward diff --git a/tests/torch/nn/parallel/pipeline_parallel/test_pp.py b/tests/torch/nn/parallel/pipeline_parallel/test_pp.py index 38bcbb13..b4719a9a 100644 --- a/tests/torch/nn/parallel/pipeline_parallel/test_pp.py +++ b/tests/torch/nn/parallel/pipeline_parallel/test_pp.py @@ -10,13 +10,147 @@ from oslo.torch.distributed import ParallelContext from oslo.torch.nn.parallel import PipelineParallel from oslo.torch.nn.parallel.utils import allocate_params +from oslo.torch.nn.parallel.pipeline_parallel._server import _MODULE_DEVICE_LOCATIONS from datasets import load_dataset -from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel, set_seed +from transformers import ( + AutoTokenizer, + GPT2Config, GPT2LMHeadModel, + T5Config, T5ForConditionalGeneration, + BartConfig, BartForConditionalGeneration, + set_seed, +) import matplotlib import matplotlib.pyplot as plt + +# for debugging +from typing import Optional, Tuple, Union +from transformers.modeling_outputs import ( + BaseModelOutput, + Seq2SeqLMOutput, +) +from torch.nn import CrossEntropyLoss + + +class T5Debug(T5ForConditionalGeneration): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: 
Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: + + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + # Convert encoder inputs in embeddings if needed + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.encoder.first_device) + self.lm_head = self.lm_head.to(self.encoder.first_device) + sequence_output = sequence_output.to(self.lm_head.weight.device) + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim**-0.5) + + lm_logits = self.lm_head(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 + + if not return_dict: + output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs + return ((loss,) + output) if loss is not None else output + + return 
Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + matplotlib.use("Agg") torch.autograd.set_detect_anomaly(True) set_seed(42) @@ -28,13 +162,23 @@ ) current_device = torch.cuda.current_device() - -model_name = "gpt2" num_micro_batches = 8 -config = GPT2Config.from_pretrained(model_name) +# model_name = "gpt2" +# config = GPT2Config.from_pretrained(model_name) +# model = GPT2LMHeadModel(config) + +# model_name = "t5-small" +# config = T5Config.from_pretrained(model_name) +# config.dropout_rate = 0. +# model = T5ForConditionalGeneration(config) +# model = T5Debug(config) + +model_name = "facebook/bart-base" +config = BartConfig.from_pretrained(model_name) +config.dropout_rate = 0. +model = BartForConditionalGeneration(config) -model = GPT2LMHeadModel(config) for n, m in model.named_modules(): if isinstance(m, nn.Dropout): @@ -57,6 +201,22 @@ allocate_params(wrapper_pp, parallel_context) +# +# def print_location_backward_hook(m, i, o): +# print(f'{torch.distributed.get_rank()=}, {m.location=}') +# return i +# +# +# for name, m in wrapper_pp.named_modules(): +# m: nn.Module +# if hasattr(m, 'location'): +# m.register_full_backward_hook(print_location_backward_hook) +# +# +if torch.distributed.get_rank() == 1: + for k, v in _MODULE_DEVICE_LOCATIONS.items(): + print(f'{k}: {v}') + def run(): batch_size = 8 * num_micro_batches From f749c4352139cfb60db6d959dd2996ce716beb6b Mon Sep 17 00:00:00 2001 From: ohwi Date: Mon, 26 Sep 2022 14:07:54 +0900 Subject: [PATCH 2/9] Update test_pp.py --- .../nn/parallel/pipeline_parallel/test_pp.py | 188 ++++++++++++++++-- 1 file changed, 175 insertions(+), 13 deletions(-) diff --git a/tests/torch/nn/parallel/pipeline_parallel/test_pp.py b/tests/torch/nn/parallel/pipeline_parallel/test_pp.py index 7c0072b1..b4719a9a 100644 --- a/tests/torch/nn/parallel/pipeline_parallel/test_pp.py +++ b/tests/torch/nn/parallel/pipeline_parallel/test_pp.py @@ -1,19 +1,155 @@ from copy import deepcopy -import matplotlib -import matplotlib.pyplot as plt import torch -import torch.distributed as dist import torch.nn as nn -from datasets import load_dataset +import torch.distributed as dist from torch.distributed import rpc from torch.optim import Adam from torch.utils.data import DataLoader -from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel, set_seed from oslo.torch.distributed import ParallelContext -from oslo.torch.nn.parallel import _PipelineParallel +from oslo.torch.nn.parallel import PipelineParallel from oslo.torch.nn.parallel.utils import allocate_params +from oslo.torch.nn.parallel.pipeline_parallel._server import _MODULE_DEVICE_LOCATIONS + +from datasets import load_dataset +from transformers import ( + AutoTokenizer, + GPT2Config, GPT2LMHeadModel, + T5Config, T5ForConditionalGeneration, + BartConfig, BartForConditionalGeneration, + set_seed, +) + +import matplotlib +import matplotlib.pyplot as plt + + +# for debugging +from typing import Optional, Tuple, Union +from transformers.modeling_outputs import ( + BaseModelOutput, + Seq2SeqLMOutput, +) +from torch.nn import CrossEntropyLoss + + +class T5Debug(T5ForConditionalGeneration): + def forward( + self, + input_ids: 
Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: + + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + # Convert encoder inputs in embeddings if needed + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.encoder.first_device) + self.lm_head = self.lm_head.to(self.encoder.first_device) + sequence_output = sequence_output.to(self.lm_head.weight.device) + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + 
sequence_output = sequence_output * (self.model_dim**-0.5) + + lm_logits = self.lm_head(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 + + if not return_dict: + output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + matplotlib.use("Agg") torch.autograd.set_detect_anomaly(True) @@ -26,13 +162,23 @@ ) current_device = torch.cuda.current_device() - -model_name = "gpt2" num_micro_batches = 8 -config = GPT2Config.from_pretrained(model_name) +# model_name = "gpt2" +# config = GPT2Config.from_pretrained(model_name) +# model = GPT2LMHeadModel(config) + +# model_name = "t5-small" +# config = T5Config.from_pretrained(model_name) +# config.dropout_rate = 0. +# model = T5ForConditionalGeneration(config) +# model = T5Debug(config) + +model_name = "facebook/bart-base" +config = BartConfig.from_pretrained(model_name) +config.dropout_rate = 0. +model = BartForConditionalGeneration(config) -model = GPT2LMHeadModel(config) for n, m in model.named_modules(): if isinstance(m, nn.Dropout): @@ -41,7 +187,7 @@ model_no_pp = deepcopy(model) model_no_pp.cuda() -wrapper_pp = _PipelineParallel( +wrapper_pp = PipelineParallel( model, parallel_context=parallel_context, memory_computation_balance=1.0, @@ -55,6 +201,22 @@ allocate_params(wrapper_pp, parallel_context) +# +# def print_location_backward_hook(m, i, o): +# print(f'{torch.distributed.get_rank()=}, {m.location=}') +# return i +# +# +# for name, m in wrapper_pp.named_modules(): +# m: nn.Module +# if hasattr(m, 'location'): +# m.register_full_backward_hook(print_location_backward_hook) +# +# +if torch.distributed.get_rank() == 1: + for k, v in _MODULE_DEVICE_LOCATIONS.items(): + print(f'{k}: {v}') + def run(): batch_size = 8 * num_micro_batches @@ -92,14 +254,14 @@ def run(): if dist.get_rank() == 0: loss_pp.backward() - print(f"{ind}") + print(f"{ind=}") cum_loss_pp += loss_pp.detach().item() out_no_pp = model_no_pp(**inputs, labels=inputs["input_ids"]) loss_no_pp = out_no_pp.loss loss_no_pp.backward() - print(f"{dist.get_rank()}, {cum_loss_pp}, {loss_no_pp}") + print(f"{dist.get_rank()=}, {cum_loss_pp=}, {loss_no_pp=}") optimizer_pp.step() optimizer_no_pp.step() From 9a8150d9a76b6289a4f9d00e1d0ab69079d50a70 Mon Sep 17 00:00:00 2001 From: ohwi Date: Mon, 26 Sep 2022 17:54:19 +0900 Subject: [PATCH 3/9] implement TensorStub --- .../parallel/pipeline_parallel/_functional.py | 16 +- .../parallel/pipeline_parallel/_messages.py | 254 +++++++++--------- .../nn/parallel/pipeline_parallel/_server.py | 91 +------ .../nn/parallel/pipeline_parallel/_utils.py | 21 +- .../pipeline_parallel/pipeline_parallel.py | 74 ++--- .../nn/parallel/pipeline_parallel/test_pp.py | 18 +- 6 files changed, 179 insertions(+), 295 deletions(-) diff --git a/oslo/torch/nn/parallel/pipeline_parallel/_functional.py 
b/oslo/torch/nn/parallel/pipeline_parallel/_functional.py index c9b8d475..b3eed56a 100644 --- a/oslo/torch/nn/parallel/pipeline_parallel/_functional.py +++ b/oslo/torch/nn/parallel/pipeline_parallel/_functional.py @@ -13,12 +13,12 @@ def add_forward_marker(mark): _FORWARD_MARKER.add(mark) - # print(f'ADD MARKER: {torch.distributed.get_rank()=}, {_FORWARD_MARKER=}') + print(f'ADD MARKER: {torch.distributed.get_rank()=}, {_FORWARD_MARKER=}') def remove_forward_marker(mark): _FORWARD_MARKER.remove(mark) - # print(f'REMOVE MARKER: {torch.distributed.get_rank()=}, {_FORWARD_MARKER=}') + print(f'REMOVE MARKER: {torch.distributed.get_rank()=}, {_FORWARD_MARKER=}') def len_forward_marker(): @@ -43,17 +43,9 @@ def reset_num_backward_done(): def launch_remote_backward(unique_key, *grad_outputs): - activation = _ACTIVATIONS.pop(unique_key, []) # TODO; is this safe? + activation = _ACTIVATIONS.pop(unique_key) - # TODO; HF output... - if isinstance(activation, dict): - activation = tuple(activation.values()) - - # TODO; some activations are not tuple... WHY? - if not isinstance(activation, tuple): - activation = (activation, ) - - # print(f'{unique_key=}, {type(activation)=}, {len(activation)=}, {len(grad_outputs)=}') + print(f'{unique_key=}, {type(activation)=}, {len(activation)=}, {len(grad_outputs)=}') # TODO; some output contains tuple of tuple.. # better way to deal with this? diff --git a/oslo/torch/nn/parallel/pipeline_parallel/_messages.py b/oslo/torch/nn/parallel/pipeline_parallel/_messages.py index ab035ee3..47625590 100644 --- a/oslo/torch/nn/parallel/pipeline_parallel/_messages.py +++ b/oslo/torch/nn/parallel/pipeline_parallel/_messages.py @@ -1,139 +1,131 @@ from dataclasses import dataclass -from typing import Optional, Any -from typing import Tuple, List, Union import torch -MESSAGE_GENERATION = 0 -REQUEST_GENERATION = 0 - - -@dataclass(init=False) -class Message: - comm_type: str - # 1. request or response - request_from: Optional[str] - # 2. request module id - exec_type: str - # 3. forward or backward - inputs: Optional[Any] - # 4. input data for module execution - outputs: Optional[Any] - # 5. output data from module execution - src_rank: int - # 6. source pp rank - dst_rank: int - # 7. destination pp rank - location: int - # 8. The location of the module within the module graph - in_autocast_context: bool - # 9. Whether the requester is currently in a autocast context - in_grad_related_context: bool - # 10. Whether the requester is currently in a no grad/enable grad context - use_activation_checkpointing: bool - # 11. 
Whether activation checkpointing is enabled for the current module - - def __init__(self): - global MESSAGE_GENERATION - MESSAGE_GENERATION += 1 - self.tag = MESSAGE_GENERATION +from oslo.torch.nn.parallel.pipeline_parallel._utils import ( + _is_namedtuple, _is_private, _is_primitive +) @dataclass class TensorStub(object): - id: str - dtype: torch.dtype - shape: Union[List, Tuple] - requires_grad: bool - - -@dataclass(init=False) -class RemoteWorkRequest: - src: torch.device - dst: torch.device - location: str - tag: int - caller: str - keys: tuple - - def __init__(self): - global REQUEST_GENERATION - REQUEST_GENERATION += 1 - self.tag = REQUEST_GENERATION - - -def generate_request(src, dst, location, caller, args, kwargs): - req = RemoteWorkRequest() - req.src = src - req.dst = dst - req.location = location - req.caller = caller - - # merge kwargs into args - keys, new_args = assemble_args(args, kwargs) - req.keys = keys - - return req, new_args - - -def assemble_args(args, kwargs): - new_args = [] - keys = [] - requires_grads = [] - for v in args: - requires_grad = False - - if torch.is_tensor(v): - v = v.contiguous() - requires_grad = v.requires_grad - - new_args.append(v) - keys.append(None) - requires_grads.append(requires_grad) - - for k, v in kwargs.items(): - if k is None: - raise ValueError("None cannot be used the key of kwargs.") - requires_grad = False - - if torch.is_tensor(v): - v = v.contiguous() - requires_grad = v.requires_grad - - new_args.append(v) - keys.append(k) - requires_grads.append(requires_grad) - - return tuple(keys), tuple(new_args), tuple(requires_grads) - - -def disassemble_new_args(new_args, keys): - args = list() - kwargs = dict() - - for k, v in zip(keys, new_args): - if k is None: - args.append(v) - else: - kwargs[k] = v - - return tuple(args), kwargs - - -def disassemble_result(result): - if isinstance(result, torch.Tensor): - args = (result,) - kwargs = dict() - wrapped = True - elif isinstance(result, dict): - args = tuple([]) - kwargs = result - wrapped = False - elif isinstance(result, (list, tuple)): - args = tuple(result) - kwargs = dict() - wrapped = False - else: - raise NotImplementedError - - return args, kwargs, wrapped + id: int + + +def pack_tensor_stub(obj, args_list): + """ + Recursively replace Tensor member variables to TensorStub. 
+ Inspiration: https://github.com/pytorch/pytorch/blob/master/torch/distributed/utils.py#L48 + """ + if torch.is_tensor(obj): + id_ = len(args_list) + tensor_sub = TensorStub(id_) + args_list.append(obj) + obj = tensor_sub + + return obj, args_list + + elif _is_namedtuple(obj): + obj_list = list(obj) + for i in range(len(obj_list)): + obj_list_i, args_list = pack_tensor_stub(obj_list[i], args_list) + obj_list_i[i] = obj_list_i + obj = obj.__class__._make(obj_list) # use namedtuple's method + + return obj, args_list + + elif isinstance(obj, tuple): + obj = list(obj) + for i in range(len(obj)): + obj_i, args_list = pack_tensor_stub(obj[i], args_list) + obj[i] = obj_i + + obj = tuple(obj) + return obj, args_list + + elif isinstance(obj, list): + for i in range(len(obj)): + obj_i, args_list = pack_tensor_stub(obj[i], args_list) + obj[i] = obj_i + + return obj, args_list + + elif isinstance(obj, dict): + for k in obj.keys(): + obj_k, args_list = pack_tensor_stub(obj[k], args_list) + obj[k] = obj_k + + return obj, args_list + + elif _is_primitive(obj): + return obj, args_list + + else: # other kinds of object + members = [ + attr for attr in dir(obj) + if not callable(getattr(obj, attr)) and not _is_private(attr) + ] + for m in members: + obj_m = getattr(obj, m) + obj_m, args_list = pack_tensor_stub(obj_m, args_list) + setattr(obj, m, obj_m) + + return obj, args_list + + +def unpack_tensor_stub(obj, args_list): + """ + Recursively replace TensorStub to original Tensor. + Inspiration: https://github.com/pytorch/pytorch/blob/master/torch/distributed/utils.py#L48 + """ + if isinstance(obj, TensorStub): + id_ = obj.id + tensor = args_list[id_] + return tensor, args_list + + elif _is_namedtuple(obj): + obj_list = list(obj) + for i in range(len(obj_list)): + obj_list_i, args_list = unpack_tensor_stub(obj_list[i], args_list) + obj_list_i[i] = obj_list_i + obj = obj.__class__._make(obj_list) # use namedtuple's method + + return obj, args_list + + elif isinstance(obj, tuple): + obj = list(obj) + for i in range(len(obj)): + obj_i, args_list = unpack_tensor_stub(obj[i], args_list) + obj[i] = obj_i + + obj = tuple(obj) + return obj, args_list + + elif isinstance(obj, list): + for i in range(len(obj)): + obj_i, args_list = unpack_tensor_stub(obj[i], args_list) + obj[i] = obj_i + + return obj, args_list + + elif isinstance(obj, dict): + for k in obj.keys(): + obj_k, args_list = unpack_tensor_stub(obj[k], args_list) + obj[k] = obj_k + + return obj, args_list + + elif _is_primitive(obj): + return obj, args_list + + else: # other kinds of object + members = [ + attr for attr in dir(obj) + if not callable(getattr(obj, attr)) and not _is_private(attr) + ] + for m in members: + obj_m = getattr(obj, m) + obj_m, args_list = unpack_tensor_stub(obj_m, args_list) + setattr(obj, m, obj_m) + + return obj, args_list diff --git a/oslo/torch/nn/parallel/pipeline_parallel/_server.py b/oslo/torch/nn/parallel/pipeline_parallel/_server.py index a11113ab..b17d9a6a 100644 --- a/oslo/torch/nn/parallel/pipeline_parallel/_server.py +++ b/oslo/torch/nn/parallel/pipeline_parallel/_server.py @@ -1,33 +1,16 @@ import time -from queue import PriorityQueue -import torch - -from oslo.torch.nn.parallel.pipeline_parallel._buffers import ( - save_activation, - pop_activation, -) -from oslo.torch.nn.parallel.pipeline_parallel._functional import ( - apply_backward_redirection, -) -from oslo.torch.nn.parallel.pipeline_parallel._messages import disassemble_new_args +from oslo.torch.nn.parallel.pipeline_parallel._buffers import save_activation 
+from oslo.torch.nn.parallel.pipeline_parallel._functional import apply_backward_redirection +from oslo.torch.nn.parallel.pipeline_parallel._messages import pack_tensor_stub, unpack_tensor_stub # original forward dictionary _ORIGINAL_FORWARDS = dict() -_WRAPPED_FORWARDS = dict() - # module device locations _MODULE_DEVICE_LOCATIONS = dict() -# Job queue -_JOB_QUEUE = PriorityQueue() - -# remote work result receiver -_RECEIVER = dict() - - _RESULT_DICT = dict() @@ -75,73 +58,23 @@ def reset_done(): _DONE_CHECKER = 0 -def reset_backward_notify(): - global _NOTIFY_BACKWARD_DONE - _NOTIFY_BACKWARD_DONE = False - - -def backward_done_notify(): - global _NOTIFY_BACKWARD_DONE - _NOTIFY_BACKWARD_DONE = True - - -def wait_backward_done(): - global _NOTIFY_BACKWARD_DONE - while not _NOTIFY_BACKWARD_DONE: - time.sleep(0.0) - - -def remote_module_forward(caller, location, unique_key, arg_keys, requires_redirection, *args): +def remote_module_forward(caller, location, unique_key, args_stub, kwargs_stub, requires_redirection, *tensors): if requires_redirection: # prepare backward redirection to caller - args = apply_backward_redirection( + tensors = apply_backward_redirection( caller, unique_key, - *args, + *tensors, ) - args, kwargs = disassemble_new_args(args, arg_keys) + (args, kwargs), _ = unpack_tensor_stub([args_stub, kwargs_stub], tensors) forward_fn = _ORIGINAL_FORWARDS[location] - # forward_fn = _WRAPPED_FORWARDS[location] result = forward_fn(*args, **kwargs) - # TODO; check whether activation need to be saved - save_activation(unique_key, result) - return result - - -def wait_remote_work_result(request_message): - tag = request_message.tag - assert tag in _RECEIVER, f"{tag}" - result = _RECEIVER[tag].get() - torch.cuda.current_stream().synchronize() - - # delete a queue for communication - _RECEIVER.pop(tag) - return result - - -def response_with_result(req, tag, result, result_wrapped): - result = (req, result, result_wrapped) - _RECEIVER[tag].put(result) - torch.cuda.current_stream().synchronize() - - -def run_remote_backward(req, *grad_outputs): - # need to ensure that grad_outputs is fully received - # TODO; no other way? - torch.cuda.synchronize() - - tag = req.tag - activation, req = pop_activation(tag) - # TODO; some output contains tuple of tuple.. - # better way to deal with this? - new_act = [] - new_grad = [] - for act, grad in zip(activation, grad_outputs): - if act is not None and grad is not None and act.requires_grad: - new_act.append(act) - new_grad.append(grad) + result_stub, tensors = pack_tensor_stub(result, []) + need_activation_save = any([t.requires_grad for t in tensors]) + if need_activation_save: + save_activation(unique_key, tensors) - torch.autograd.backward(tuple(new_act), tuple(new_grad)) + return result_stub, tensors, need_activation_save diff --git a/oslo/torch/nn/parallel/pipeline_parallel/_utils.py b/oslo/torch/nn/parallel/pipeline_parallel/_utils.py index b8bcdb26..0869e2b1 100644 --- a/oslo/torch/nn/parallel/pipeline_parallel/_utils.py +++ b/oslo/torch/nn/parallel/pipeline_parallel/_utils.py @@ -27,10 +27,17 @@ def post_order_traverse(node): yield node -def is_iterable(data): - try: - iter(data) - except TypeError: - return False - else: - return True +# from https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/scatter_gather.py#L12 +def _is_namedtuple(obj): + # Check if type was created from collections.namedtuple or a typing.NamedTuple. 
+ return ( + isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields") + ) + + +def _is_primitive(obj): + return not hasattr(obj, '__dict__') + + +def _is_private(attr): + return attr.startswith('__') diff --git a/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py b/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py index 28d574b2..498363cd 100644 --- a/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py +++ b/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py @@ -13,11 +13,10 @@ apply_backward_redirection, len_forward_marker, ) -from oslo.torch.nn.parallel.pipeline_parallel._messages import assemble_args +from oslo.torch.nn.parallel.pipeline_parallel._messages import pack_tensor_stub, unpack_tensor_stub from oslo.torch.nn.parallel.pipeline_parallel._model_partitioner import ModelPartitioner from oslo.torch.nn.parallel.pipeline_parallel._server import ( _ORIGINAL_FORWARDS, - _WRAPPED_FORWARDS, _MODULE_DEVICE_LOCATIONS, remote_module_forward, _RESULT_DICT, @@ -41,7 +40,12 @@ def PipelineParallel( num_micro_batches: int = 1, ): # TODO, @HG - pass + return _PipelineParallel( + module=module, + parallel_context=parallel_context, + memory_computation_balance=memory_computation_balance, + num_micro_batches=num_micro_batches + ) class _PipelineParallel(nn.Module): @@ -210,11 +214,10 @@ def new_forward(*args, **kwargs): result = forward_fn(*args, **kwargs) else: - arg_keys, new_args, requires_grads = assemble_args(args, kwargs) + (args_stub, kwargs_stub), tensors = pack_tensor_stub([args, kwargs], []) + tensors = tuple(tensors) - print(f'{new_args=}') - - need_activation_save = any(requires_grads) + need_activation_save = any([t.requires_grad for t in tensors]) with self._lock: cnt = get_forward_counter(location) unique_key = (location, cnt, self.rank) @@ -222,15 +225,7 @@ def new_forward(*args, **kwargs): if need_activation_save: # prepare backward - save_activation(unique_key, new_args) - - # with self._lock: - # cnt = get_forward_counter(location) - # unique_key = (location, cnt) - # increment_forward_counter(location) - # - # # prepare backward - # save_activation(unique_key, new_args) + save_activation(unique_key, tensors) caller = self.parallel_context.get_pipeline_rpc_worker_name( current_device.index @@ -243,55 +238,20 @@ def new_forward(*args, **kwargs): fut = rpc.rpc_async( to=callee, func=remote_module_forward, - args=(caller, location, unique_key, arg_keys, need_activation_save) + new_args, + args=(caller, location, unique_key, args_stub, kwargs_stub, need_activation_save) + tensors, ) - result = fut.wait() - - # TODO; does result always be an args? what if dict? - # HF output is OrderedDict!! - # need to deal with recursive case... 
- is_dict = False - orig_result = None - if isinstance(result, dict): - is_dict = True - orig_result = result - - # print(f'{result.keys()=}') - # print(f'{len(result)=}') - - result = tuple(result.values()) - - # print(f'{len(result)=}') - - wrapped = False - if not isinstance(result, tuple): - wrapped = True - result = (result, ) - - # re-check - requires_redirection = False - for x in result: - if torch.is_tensor(x): - if x.requires_grad: - requires_redirection = True - break + # receive result as stub + result_stub, tensors, requires_redirection = fut.wait() if requires_redirection: - result = apply_backward_redirection( + tensors = apply_backward_redirection( callee, unique_key, - *result, + *tensors, ) - if wrapped: - result = result[0] - - if is_dict: - for i, k in enumerate(orig_result.keys()): - orig_result[k] = result[i] - result = orig_result + result, _ = unpack_tensor_stub(result_stub, tensors) return result module.forward = new_forward - _WRAPPED_FORWARDS[loc] = new_forward diff --git a/tests/torch/nn/parallel/pipeline_parallel/test_pp.py b/tests/torch/nn/parallel/pipeline_parallel/test_pp.py index b4719a9a..fc50a7ee 100644 --- a/tests/torch/nn/parallel/pipeline_parallel/test_pp.py +++ b/tests/torch/nn/parallel/pipeline_parallel/test_pp.py @@ -162,22 +162,22 @@ def forward( ) current_device = torch.cuda.current_device() -num_micro_batches = 8 +num_micro_batches = 1 # model_name = "gpt2" # config = GPT2Config.from_pretrained(model_name) # model = GPT2LMHeadModel(config) -# model_name = "t5-small" -# config = T5Config.from_pretrained(model_name) -# config.dropout_rate = 0. -# model = T5ForConditionalGeneration(config) +model_name = "t5-small" +config = T5Config.from_pretrained(model_name) +config.dropout_rate = 0. +model = T5ForConditionalGeneration(config) # model = T5Debug(config) -model_name = "facebook/bart-base" -config = BartConfig.from_pretrained(model_name) -config.dropout_rate = 0. -model = BartForConditionalGeneration(config) +# model_name = "facebook/bart-base" +# config = BartConfig.from_pretrained(model_name) +# config.dropout_rate = 0. 
+# model = BartForConditionalGeneration(config) for n, m in model.named_modules(): From 21c74f6e27f92fad82802aed0025b5ea6504c1c3 Mon Sep 17 00:00:00 2001 From: ohwi Date: Wed, 28 Sep 2022 18:54:17 +0900 Subject: [PATCH 4/9] clean PP code --- .../nn/parallel/pipeline_parallel/_buffers.py | 26 ++++- .../parallel/pipeline_parallel/_functional.py | 83 +++++++--------- .../nn/parallel/pipeline_parallel/_server.py | 80 --------------- .../nn/parallel/pipeline_parallel/_sync.py | 99 +++++++++++++++++++ .../pipeline_parallel/pipeline_parallel.py | 92 +++++++---------- .../nn/parallel/pipeline_parallel/test_pp.py | 18 ++-- 6 files changed, 206 insertions(+), 192 deletions(-) delete mode 100644 oslo/torch/nn/parallel/pipeline_parallel/_server.py create mode 100644 oslo/torch/nn/parallel/pipeline_parallel/_sync.py diff --git a/oslo/torch/nn/parallel/pipeline_parallel/_buffers.py b/oslo/torch/nn/parallel/pipeline_parallel/_buffers.py index 4f8f63c4..3cb04574 100644 --- a/oslo/torch/nn/parallel/pipeline_parallel/_buffers.py +++ b/oslo/torch/nn/parallel/pipeline_parallel/_buffers.py @@ -1,3 +1,27 @@ +from oslo.torch.nn.parallel.pipeline_parallel._sync import register_location_for_forward_counter + + +# original forward dictionary +_ORIGINAL_FORWARDS = dict() + +# module device locations +_MODULE_DEVICE_LOCATIONS = dict() + + +def register_original_forward_function(location, func, device): + _ORIGINAL_FORWARDS[location] = func + _MODULE_DEVICE_LOCATIONS[location] = device + register_location_for_forward_counter(location) + + +def get_original_forward_function(location): + return _ORIGINAL_FORWARDS[location] + + +def get_module_device_location(location): + return _MODULE_DEVICE_LOCATIONS[location] + + # Activations _ACTIVATIONS = dict() @@ -7,4 +31,4 @@ def save_activation(key, activation): def pop_activation(key): - return _ACTIVATIONS.pop(key) + return _ACTIVATIONS.pop(key, []) # TODO; okay? 
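A minimal round-trip sketch of the TensorStub helpers introduced in PATCH 3, assuming the module layout above; the nested structure below is illustrative only. pack_tensor_stub swaps every tensor in a nested args/kwargs/output object for a positional TensorStub and collects the tensors into a flat, ordered list (so only raw tensors need to cross the RPC boundary and be wired into backward redirection), while unpack_tensor_stub restores the original structure from that list.

import torch
from oslo.torch.nn.parallel.pipeline_parallel._messages import (
    TensorStub,
    pack_tensor_stub,
    unpack_tensor_stub,
)

# Illustrative, HF-style nested output mixing tensors and plain Python values.
logits = torch.randn(2, 4, requires_grad=True)
output = {"logits": logits, "past": (torch.randn(2, 4), None), "n_items": 3}

# Pack: each tensor is replaced by a TensorStub carrying its position,
# and the tensors themselves are gathered into a flat, ordered list.
stub, tensors = pack_tensor_stub(output, [])
assert isinstance(stub["logits"], TensorStub) and len(tensors) == 2

# The flat tensor list is what travels through rpc (and, when any element
# requires grad, through apply_backward_redirection / save_activation).
restored, _ = unpack_tensor_stub(stub, tensors)
assert restored["logits"] is logits
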
diff --git a/oslo/torch/nn/parallel/pipeline_parallel/_functional.py b/oslo/torch/nn/parallel/pipeline_parallel/_functional.py index b3eed56a..4a1dd670 100644 --- a/oslo/torch/nn/parallel/pipeline_parallel/_functional.py +++ b/oslo/torch/nn/parallel/pipeline_parallel/_functional.py @@ -2,53 +2,43 @@ from torch.cuda.amp import custom_fwd, custom_bwd from torch.distributed import rpc -from oslo.torch.nn.parallel.pipeline_parallel._buffers import _ACTIVATIONS - -_FORWARD_MARKER = set() - -_LOCAL_BACKWARD_DONE = False - -_NUM_BACKWARD_DONE = 0 - - -def add_forward_marker(mark): - _FORWARD_MARKER.add(mark) - print(f'ADD MARKER: {torch.distributed.get_rank()=}, {_FORWARD_MARKER=}') - - -def remove_forward_marker(mark): - _FORWARD_MARKER.remove(mark) - print(f'REMOVE MARKER: {torch.distributed.get_rank()=}, {_FORWARD_MARKER=}') - - -def len_forward_marker(): - # print(f'{torch.distributed.get_rank()=}, {_FORWARD_MARKER=}') - return len(_FORWARD_MARKER) - - -def increase_num_backward_done(): - global _NUM_BACKWARD_DONE - _NUM_BACKWARD_DONE += 1 +from oslo.torch.nn.parallel.pipeline_parallel._buffers import ( + get_original_forward_function, + save_activation, + pop_activation, +) +from oslo.torch.nn.parallel.pipeline_parallel._sync import ( + register_job_requires_backward, + notify_backward_job_done, +) +from oslo.torch.nn.parallel.pipeline_parallel._messages import pack_tensor_stub, unpack_tensor_stub + + +def remote_module_forward(caller, location, unique_key, args_stub, kwargs_stub, requires_redirection, *tensors): + if requires_redirection: + # prepare backward redirection to caller + tensors = apply_backward_redirection( + caller, + unique_key, + *tensors, + ) + (args, kwargs), _ = unpack_tensor_stub([args_stub, kwargs_stub], tensors) -def get_num_backward_done(): - global _NUM_BACKWARD_DONE - return _NUM_BACKWARD_DONE + forward_fn = get_original_forward_function(location) + result = forward_fn(*args, **kwargs) + result_stub, tensors = pack_tensor_stub(result, []) + need_activation_save = any([t.requires_grad for t in tensors]) + if need_activation_save: + save_activation(unique_key, tensors) -def reset_num_backward_done(): - global _NUM_BACKWARD_DONE, _LOCAL_BACKWARD_DONE - _NUM_BACKWARD_DONE = 0 - _LOCAL_BACKWARD_DONE = False + return result_stub, tensors, need_activation_save def launch_remote_backward(unique_key, *grad_outputs): - activation = _ACTIVATIONS.pop(unique_key) - - print(f'{unique_key=}, {type(activation)=}, {len(activation)=}, {len(grad_outputs)=}') + activation = pop_activation(unique_key) - # TODO; some output contains tuple of tuple.. - # better way to deal with this? new_act = [] new_grad = [] for act, grad in zip(activation, grad_outputs): @@ -56,18 +46,20 @@ def launch_remote_backward(unique_key, *grad_outputs): new_act.append(act) new_grad.append(grad) - torch.autograd.backward(tuple(new_act), tuple(new_grad)) - remove_forward_marker(unique_key) + if len(new_act) > 0 and len(new_grad) > 0: + torch.autograd.backward(tuple(new_act), tuple(new_grad)) + notify_backward_job_done(unique_key) -# TODO; why +# why # forward(ctx, req, *args, **kwargs) # ... # return args, kwargs # does not work??? 
-# -> +# # because that is the design of Pytorch # see: github.com/pytorch/pytorch/issues/16940 +# # based on https://github.com/facebookresearch/fairscale/blob/main/fairscale/nn/pipe/rpc.py#L53 class _PipeBackwardRedirection(torch.autograd.Function): @staticmethod @@ -75,18 +67,17 @@ class _PipeBackwardRedirection(torch.autograd.Function): def forward(ctx, to, unique_key, *args): ctx.to = to ctx.unique_key = unique_key - ctx.num_nones = 2 + len(args) # counting req + ctx.num_nones = 2 + len(args) # mark # TODO; can we do this before remote_forward # without rpc call? - rpc.rpc_sync(to=to, func=add_forward_marker, args=(unique_key,)) + rpc.rpc_sync(to=to, func=register_job_requires_backward, args=(unique_key,)) return args @staticmethod @custom_bwd - @rpc.functions.async_execution def backward(ctx, *grad_outputs): to = ctx.to unique_key = ctx.unique_key diff --git a/oslo/torch/nn/parallel/pipeline_parallel/_server.py b/oslo/torch/nn/parallel/pipeline_parallel/_server.py deleted file mode 100644 index b17d9a6a..00000000 --- a/oslo/torch/nn/parallel/pipeline_parallel/_server.py +++ /dev/null @@ -1,80 +0,0 @@ -import time - -from oslo.torch.nn.parallel.pipeline_parallel._buffers import save_activation -from oslo.torch.nn.parallel.pipeline_parallel._functional import apply_backward_redirection -from oslo.torch.nn.parallel.pipeline_parallel._messages import pack_tensor_stub, unpack_tensor_stub - -# original forward dictionary -_ORIGINAL_FORWARDS = dict() - -# module device locations -_MODULE_DEVICE_LOCATIONS = dict() - - -_RESULT_DICT = dict() - - -_DONE_CHECKER = 0 - - -_FORWARD_COUNTER = dict() - - -def get_result(ind): - while ind not in _RESULT_DICT: - time.sleep(0.0) - return _RESULT_DICT[ind] - - -def reset_result(): - _RESULT_DICT.clear() - - -def get_forward_counter(loc): - return _FORWARD_COUNTER[loc] - - -def increment_forward_counter(loc): - _FORWARD_COUNTER[loc] += 1 - - -def reset_forward_counter(): - for k in _FORWARD_COUNTER: - _FORWARD_COUNTER[k] = 0 - - -def increment_done(): - global _DONE_CHECKER - _DONE_CHECKER += 1 - - -def get_done(): - global _DONE_CHECKER - return _DONE_CHECKER - - -def reset_done(): - global _DONE_CHECKER - _DONE_CHECKER = 0 - - -def remote_module_forward(caller, location, unique_key, args_stub, kwargs_stub, requires_redirection, *tensors): - if requires_redirection: - # prepare backward redirection to caller - tensors = apply_backward_redirection( - caller, - unique_key, - *tensors, - ) - - (args, kwargs), _ = unpack_tensor_stub([args_stub, kwargs_stub], tensors) - - forward_fn = _ORIGINAL_FORWARDS[location] - result = forward_fn(*args, **kwargs) - - result_stub, tensors = pack_tensor_stub(result, []) - need_activation_save = any([t.requires_grad for t in tensors]) - if need_activation_save: - save_activation(unique_key, tensors) - - return result_stub, tensors, need_activation_save diff --git a/oslo/torch/nn/parallel/pipeline_parallel/_sync.py b/oslo/torch/nn/parallel/pipeline_parallel/_sync.py new file mode 100644 index 00000000..a9d9cd9a --- /dev/null +++ b/oslo/torch/nn/parallel/pipeline_parallel/_sync.py @@ -0,0 +1,99 @@ +import time + +from torch.distributed import rpc + +from oslo.torch.distributed.parallel_mode import ParallelMode + + +# for watching whether every backward work is done or not +_JOBS_REQUIRE_BACKWARD = set() + + +def register_job_requires_backward(job_name): + _JOBS_REQUIRE_BACKWARD.add(job_name) + + +def notify_backward_job_done(job_name): + _JOBS_REQUIRE_BACKWARD.remove(job_name) + + +def 
get_num_jobs_require_backward_remaining(): + return len(_JOBS_REQUIRE_BACKWARD) + + +# for unique tag generation +_NUM_FORWARD_USED_COUNTER = dict() + + +def register_location_for_forward_counter(location): + _NUM_FORWARD_USED_COUNTER[location] = 0 + + +def make_unique_key(location, rank): + cnt = _NUM_FORWARD_USED_COUNTER[location] + unique_key = (location, cnt, rank) + _NUM_FORWARD_USED_COUNTER[location] += 1 + return unique_key + + +def reset_forward_used_counter(): + for k in _NUM_FORWARD_USED_COUNTER: + _NUM_FORWARD_USED_COUNTER[k] = 0 + + +# dictionary for result broadcast +_RESULT_DICT = dict() + +_RESULT_RECEIVED_MARKER = dict() + + +def set_result(ind, result): + _RESULT_DICT[ind] = result + _RESULT_RECEIVED_MARKER[ind] = True + + +def get_result(ind): + while ind not in _RESULT_RECEIVED_MARKER: + time.sleep(0.0) + return _RESULT_DICT[ind] + + +def reset_result(): + _RESULT_DICT.clear() + _RESULT_RECEIVED_MARKER.clear() + + +# +_CHECKER_BATCH_JOB_FINISHED = 0 + + +def notify_batch_job_finished(): + global _CHECKER_BATCH_JOB_FINISHED + _CHECKER_BATCH_JOB_FINISHED += 1 + + +def wait_other_ranks(rank, context): + global _CHECKER_BATCH_JOB_FINISHED + + # TODO; check the reason why we need this code block + # for checking batch job done. + # gradient computation goes wrong without this code + for other in context.get_ranks_in_group(ParallelMode.PIPELINE): + if other == rank: + notify_batch_job_finished() + else: + rpc_dst = context.get_pipeline_rpc_worker_name(other) + rpc.rpc_sync( + to=rpc_dst, + func=notify_batch_job_finished, + ) + + while _CHECKER_BATCH_JOB_FINISHED < context.get_world_size(ParallelMode.PIPELINE): + time.sleep(0.0) + + # every ranks done; reset + _CHECKER_BATCH_JOB_FINISHED = 0 + + # wait for all backward pass execution + while get_num_jobs_require_backward_remaining() != 0: + time.sleep(0.0) diff --git a/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py b/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py index 498363cd..0f685945 100644 --- a/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py +++ b/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py @@ -8,29 +8,23 @@ from oslo.torch.distributed.parallel_context import ParallelContext from oslo.torch.distributed.parallel_mode import ParallelMode -from oslo.torch.nn.parallel.pipeline_parallel._buffers import save_activation -from oslo.torch.nn.parallel.pipeline_parallel._functional import ( - apply_backward_redirection, - len_forward_marker, +from oslo.torch.nn.parallel.utils import get_parallel_context +from oslo.torch.nn.parallel.pipeline_parallel._buffers import ( + register_original_forward_function, + get_original_forward_function, + get_module_device_location, + save_activation, ) -from oslo.torch.nn.parallel.pipeline_parallel._messages import pack_tensor_stub, unpack_tensor_stub -from oslo.torch.nn.parallel.pipeline_parallel._model_partitioner import ModelPartitioner -from oslo.torch.nn.parallel.pipeline_parallel._server import ( - _ORIGINAL_FORWARDS, - _MODULE_DEVICE_LOCATIONS, - remote_module_forward, - _RESULT_DICT, +from oslo.torch.nn.parallel.pipeline_parallel._functional import remote_module_forward, apply_backward_redirection +from oslo.torch.nn.parallel.pipeline_parallel._sync import ( + wait_other_ranks, + make_unique_key, + reset_forward_used_counter, + set_result, get_result, - reset_result, - increment_done, - get_done, - reset_done, - _FORWARD_COUNTER, - get_forward_counter, - increment_forward_counter, - reset_forward_counter, ) -from 
oslo.torch.nn.parallel.utils import get_parallel_context +from oslo.torch.nn.parallel.pipeline_parallel._messages import pack_tensor_stub, unpack_tensor_stub +from oslo.torch.nn.parallel.pipeline_parallel._model_partitioner import ModelPartitioner def PipelineParallel( @@ -103,11 +97,12 @@ def __init__( def forward(self, *args, **kwargs): # set forward counter to zero - reset_forward_counter() + reset_forward_used_counter() # to ensure optimizer's step is done for all processes - # TODO; barrier for only PP - torch.distributed.barrier() + torch.distributed.barrier( + self.parallel_context.get_group(ParallelMode.PIPELINE) + ) if self.rank == 0: # TODO; @@ -140,16 +135,17 @@ def forward(self, *args, **kwargs): for i, done in enumerate(concurrent.futures.as_completed(futures)): result = done.result() - _RESULT_DICT[i] = result - - # print(f'{i=}, {result.loss=}, {dist.get_rank()=}') + set_result(i, result) yield result else: + # TODO; the code block below does not make + # same number with rank 0. However, since + # this result is a dummy, it does not cause + # an error. # forward pass end, wait results from master for i in range(self.num_micro_batches): - # result = FINAL_RESULT_QUEUE.get() rpc_dst = self.parallel_context.get_pipeline_rpc_worker_name(0) result = rpc.rpc_sync( to=rpc_dst, @@ -157,28 +153,18 @@ def forward(self, *args, **kwargs): args=(i,), ) - yield result # has no gradient + # remove gradient of non-master result. + # without this, the users need to consider rank + # when calling loss.backward() + result, tensors = pack_tensor_stub(result, []) + for i_tensor in range(len(tensors)): + tensors[i_tensor].grad = None + result, _ = unpack_tensor_stub(result, tensors) - # barrier ? - # TODO; check the reason why we need this code block - for other in self.parallel_context.get_ranks_in_group(ParallelMode.PIPELINE): - if other == self.rank: - increment_done() - else: - rpc_dst = self.parallel_context.get_pipeline_rpc_worker_name(other) - rpc.rpc_sync( - to=rpc_dst, - func=increment_done, - ) - - while get_done() < self.parallel_context.get_world_size(ParallelMode.PIPELINE): - time.sleep(0.0) - - reset_done() - reset_result() + yield result - while len_forward_marker() != 0: - time.sleep(0.0) + # barrier; wait for all rank + wait_other_ranks(self.rank, self.parallel_context) torch.cuda.empty_cache() @@ -197,20 +183,18 @@ def _wrap_forward(self, module): loc = module.location device = module.oslo_parallel[ParallelMode.PIPELINE] - _ORIGINAL_FORWARDS[loc] = orig_forward - _MODULE_DEVICE_LOCATIONS[loc] = device - _FORWARD_COUNTER[loc] = 0 + register_original_forward_function(loc, orig_forward, device) def new_forward(*args, **kwargs): location = module.location - module_device = _MODULE_DEVICE_LOCATIONS[location] + module_device = get_module_device_location(location) module_device = torch.device("cuda", module_device) current_device = self.parallel_context.get_local_rank(ParallelMode.PIPELINE) current_device = torch.device("cuda", current_device) is_same = module_device == current_device if is_same: - forward_fn = _ORIGINAL_FORWARDS[location] + forward_fn = get_original_forward_function(location) result = forward_fn(*args, **kwargs) else: @@ -219,9 +203,7 @@ def new_forward(*args, **kwargs): need_activation_save = any([t.requires_grad for t in tensors]) with self._lock: - cnt = get_forward_counter(location) - unique_key = (location, cnt, self.rank) - increment_forward_counter(location) + unique_key = make_unique_key(location, self.rank) if need_activation_save: # prepare backward diff 
--git a/tests/torch/nn/parallel/pipeline_parallel/test_pp.py b/tests/torch/nn/parallel/pipeline_parallel/test_pp.py index fc50a7ee..a19cf10a 100644 --- a/tests/torch/nn/parallel/pipeline_parallel/test_pp.py +++ b/tests/torch/nn/parallel/pipeline_parallel/test_pp.py @@ -10,7 +10,7 @@ from oslo.torch.distributed import ParallelContext from oslo.torch.nn.parallel import PipelineParallel from oslo.torch.nn.parallel.utils import allocate_params -from oslo.torch.nn.parallel.pipeline_parallel._server import _MODULE_DEVICE_LOCATIONS +from oslo.torch.nn.parallel.pipeline_parallel._buffers import _MODULE_DEVICE_LOCATIONS from datasets import load_dataset from transformers import ( @@ -157,12 +157,12 @@ def forward( parallel_context = ParallelContext.from_torch( data_parallel_size=1, - pipeline_parallel_size=2, + pipeline_parallel_size=4, tensor_parallel_size=1, ) current_device = torch.cuda.current_device() -num_micro_batches = 1 +num_micro_batches = 8 # model_name = "gpt2" # config = GPT2Config.from_pretrained(model_name) @@ -196,8 +196,8 @@ def forward( wrapper_pp.train() -optimizer_pp = Adam(wrapper_pp.parameters(), lr=3e-5) -optimizer_no_pp = Adam(model_no_pp.parameters(), lr=3e-5) +optimizer_pp = Adam(wrapper_pp.parameters(), lr=3e-2) +optimizer_no_pp = Adam(model_no_pp.parameters(), lr=3e-2) allocate_params(wrapper_pp, parallel_context) @@ -250,18 +250,16 @@ def run(): ): loss_pp = out_pp.loss loss_pp = loss_pp / num_micro_batches + loss_pp.backward() - if dist.get_rank() == 0: - loss_pp.backward() - - print(f"{ind=}") cum_loss_pp += loss_pp.detach().item() out_no_pp = model_no_pp(**inputs, labels=inputs["input_ids"]) loss_no_pp = out_no_pp.loss loss_no_pp.backward() - print(f"{dist.get_rank()=}, {cum_loss_pp=}, {loss_no_pp=}") + if dist.get_rank() == 0: + print(f"{dist.get_rank()=}, {cum_loss_pp=}, {loss_no_pp=}") optimizer_pp.step() optimizer_no_pp.step() From 80676f43a402ff5119c0d2e7ce9a2d267e1d62c6 Mon Sep 17 00:00:00 2001 From: ohwi Date: Thu, 29 Sep 2022 14:48:50 +0900 Subject: [PATCH 5/9] remove activation save in eval mode --- .../nn/parallel/pipeline_parallel/_functional.py | 9 +++++++-- .../parallel/pipeline_parallel/pipeline_parallel.py | 11 ++++++++--- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/oslo/torch/nn/parallel/pipeline_parallel/_functional.py b/oslo/torch/nn/parallel/pipeline_parallel/_functional.py index 4a1dd670..2a0edc6c 100644 --- a/oslo/torch/nn/parallel/pipeline_parallel/_functional.py +++ b/oslo/torch/nn/parallel/pipeline_parallel/_functional.py @@ -14,7 +14,12 @@ from oslo.torch.nn.parallel.pipeline_parallel._messages import pack_tensor_stub, unpack_tensor_stub -def remote_module_forward(caller, location, unique_key, args_stub, kwargs_stub, requires_redirection, *tensors): +def remote_module_forward( + caller, location, unique_key, + args_stub, kwargs_stub, + requires_redirection, training, + *tensors +): if requires_redirection: # prepare backward redirection to caller tensors = apply_backward_redirection( @@ -29,7 +34,7 @@ def remote_module_forward(caller, location, unique_key, args_stub, kwargs_stub, result = forward_fn(*args, **kwargs) result_stub, tensors = pack_tensor_stub(result, []) - need_activation_save = any([t.requires_grad for t in tensors]) + need_activation_save = any([t.requires_grad for t in tensors]) and training if need_activation_save: save_activation(unique_key, tensors) diff --git a/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py b/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py index 
0f685945..74a33894 100644 --- a/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py +++ b/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py @@ -201,7 +201,8 @@ def new_forward(*args, **kwargs): (args_stub, kwargs_stub), tensors = pack_tensor_stub([args, kwargs], []) tensors = tuple(tensors) - need_activation_save = any([t.requires_grad for t in tensors]) + # does not save activation if the module is in eval mode + need_activation_save = any([t.requires_grad for t in tensors]) and self.training with self._lock: unique_key = make_unique_key(location, self.rank) @@ -220,12 +221,16 @@ def new_forward(*args, **kwargs): fut = rpc.rpc_async( to=callee, func=remote_module_forward, - args=(caller, location, unique_key, args_stub, kwargs_stub, need_activation_save) + tensors, + args=( + caller, location, unique_key, + args_stub, kwargs_stub, + need_activation_save, self.training + ) + tensors, ) # receive result as stub result_stub, tensors, requires_redirection = fut.wait() - if requires_redirection: + if requires_redirection and self.training: tensors = apply_backward_redirection( callee, unique_key, From 18eeffcd8affbab6cb8c3519d17fd705aca65f18 Mon Sep 17 00:00:00 2001 From: ohwi Date: Thu, 29 Sep 2022 15:49:31 +0900 Subject: [PATCH 6/9] support torch.no_grad() --- .../parallel/pipeline_parallel/_functional.py | 11 +++++--- .../pipeline_parallel/pipeline_parallel.py | 26 ++++++++++++++----- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/oslo/torch/nn/parallel/pipeline_parallel/_functional.py b/oslo/torch/nn/parallel/pipeline_parallel/_functional.py index 2a0edc6c..d67315b3 100644 --- a/oslo/torch/nn/parallel/pipeline_parallel/_functional.py +++ b/oslo/torch/nn/parallel/pipeline_parallel/_functional.py @@ -17,10 +17,12 @@ def remote_module_forward( caller, location, unique_key, args_stub, kwargs_stub, - requires_redirection, training, + requires_redirection, + is_training, + is_grad_enabled, *tensors ): - if requires_redirection: + if requires_redirection and is_training and is_grad_enabled: # prepare backward redirection to caller tensors = apply_backward_redirection( caller, @@ -31,10 +33,11 @@ def remote_module_forward( (args, kwargs), _ = unpack_tensor_stub([args_stub, kwargs_stub], tensors) forward_fn = get_original_forward_function(location) - result = forward_fn(*args, **kwargs) + with torch.set_grad_enabled(is_grad_enabled): + result = forward_fn(*args, **kwargs) result_stub, tensors = pack_tensor_stub(result, []) - need_activation_save = any([t.requires_grad for t in tensors]) and training + need_activation_save = any([t.requires_grad for t in tensors]) and is_training and is_grad_enabled if need_activation_save: save_activation(unique_key, tensors) diff --git a/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py b/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py index 74a33894..f13acf92 100644 --- a/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py +++ b/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py @@ -1,5 +1,4 @@ import concurrent.futures -import time from threading import Lock import torch @@ -42,6 +41,15 @@ def PipelineParallel( ) +# function to launch self.module. needs this +# function because torch.set_grad_enabled() is +# thread local. 
+def launch(fn, is_grad_enabled, *args, **kwargs): + with torch.set_grad_enabled(is_grad_enabled): + result = fn(*args, **kwargs) + return result + + class _PipelineParallel(nn.Module): """ Pipeline parallel module @@ -129,8 +137,9 @@ def forward(self, *args, **kwargs): # futures.append(future) # ind += 1 + is_grad_enabled = torch.is_grad_enabled() for ind, (args_, kwargs_) in enumerate(zip(new_args, new_kwargs)): - future = self.producer.submit(self.module, *args_, **kwargs_) + future = self.producer.submit(launch, self.module, is_grad_enabled, *args_, **kwargs_) futures.append(future) for i, done in enumerate(concurrent.futures.as_completed(futures)): @@ -166,6 +175,7 @@ def forward(self, *args, **kwargs): # barrier; wait for all rank wait_other_ranks(self.rank, self.parallel_context) + # TODO; seems like this is not necessary? torch.cuda.empty_cache() def _recursive_wrap(self, module, prefix): @@ -202,11 +212,13 @@ def new_forward(*args, **kwargs): tensors = tuple(tensors) # does not save activation if the module is in eval mode - need_activation_save = any([t.requires_grad for t in tensors]) and self.training + is_grad_enabled = torch.is_grad_enabled() + is_training = self.training + need_activation_save = any([t.requires_grad for t in tensors]) with self._lock: unique_key = make_unique_key(location, self.rank) - if need_activation_save: + if need_activation_save and is_training and is_grad_enabled: # prepare backward save_activation(unique_key, tensors) @@ -224,13 +236,15 @@ def new_forward(*args, **kwargs): args=( caller, location, unique_key, args_stub, kwargs_stub, - need_activation_save, self.training + need_activation_save, + is_training, + is_grad_enabled, ) + tensors, ) # receive result as stub result_stub, tensors, requires_redirection = fut.wait() - if requires_redirection and self.training: + if requires_redirection and is_training and is_grad_enabled: tensors = apply_backward_redirection( callee, unique_key, From da6c087af5a3b2025bf88ff7e3469511863a6528 Mon Sep 17 00:00:00 2001 From: ohwi Date: Thu, 29 Sep 2022 17:24:32 +0900 Subject: [PATCH 7/9] support amp (WIP) --- .../parallel/pipeline_parallel/_functional.py | 11 +++++- .../pipeline_parallel/pipeline_parallel.py | 37 ++++++++++++++++-- .../nn/parallel/pipeline_parallel/test_pp.py | 39 +++++++++++++------ 3 files changed, 72 insertions(+), 15 deletions(-) diff --git a/oslo/torch/nn/parallel/pipeline_parallel/_functional.py b/oslo/torch/nn/parallel/pipeline_parallel/_functional.py index d67315b3..71460bf4 100644 --- a/oslo/torch/nn/parallel/pipeline_parallel/_functional.py +++ b/oslo/torch/nn/parallel/pipeline_parallel/_functional.py @@ -20,6 +20,7 @@ def remote_module_forward( requires_redirection, is_training, is_grad_enabled, + autocast_information, *tensors ): if requires_redirection and is_training and is_grad_enabled: @@ -33,7 +34,15 @@ def remote_module_forward( (args, kwargs), _ = unpack_tensor_stub([args_stub, kwargs_stub], tensors) forward_fn = get_original_forward_function(location) - with torch.set_grad_enabled(is_grad_enabled): + + is_autocast_enabled, is_autocast_cache_enabled, prev_fastdtype = autocast_information + autocast = torch.cuda.amp.autocast( + enabled=is_autocast_enabled, + dtype=prev_fastdtype, + cache_enabled=is_autocast_cache_enabled, + ) + + with torch.set_grad_enabled(is_grad_enabled), autocast: result = forward_fn(*args, **kwargs) result_stub, tensors = pack_tensor_stub(result, []) diff --git a/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py 
b/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py index f13acf92..1efcce1a 100644 --- a/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py +++ b/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py @@ -44,8 +44,14 @@ def PipelineParallel( # function to launch self.module. needs this # function because torch.set_grad_enabled() is # thread local. -def launch(fn, is_grad_enabled, *args, **kwargs): - with torch.set_grad_enabled(is_grad_enabled): +def launch(fn, is_grad_enabled, autocast_information, *args, **kwargs): + is_autocast_enabled, is_autocast_cache_enabled, prev_fastdtype = autocast_information + autocast = torch.cuda.amp.autocast( + enabled=is_autocast_enabled, + dtype=prev_fastdtype, + cache_enabled=is_autocast_cache_enabled, + ) + with torch.set_grad_enabled(is_grad_enabled), autocast: result = fn(*args, **kwargs) return result @@ -138,8 +144,23 @@ def forward(self, *args, **kwargs): # ind += 1 is_grad_enabled = torch.is_grad_enabled() + is_autocast_enabled = torch.is_autocast_enabled() + is_autocast_cache_enabled = torch.is_autocast_cache_enabled() + prev_fastdtype = torch.get_autocast_gpu_dtype() + autocast_information = ( + is_autocast_enabled, + is_autocast_cache_enabled, + prev_fastdtype, + ) for ind, (args_, kwargs_) in enumerate(zip(new_args, new_kwargs)): - future = self.producer.submit(launch, self.module, is_grad_enabled, *args_, **kwargs_) + future = self.producer.submit( + launch, + self.module, + is_grad_enabled, + autocast_information, + *args_, + **kwargs_ + ) futures.append(future) for i, done in enumerate(concurrent.futures.as_completed(futures)): @@ -229,6 +250,15 @@ def new_forward(*args, **kwargs): module_device.index ) + is_autocast_enabled = torch.is_autocast_enabled() + is_autocast_cache_enabled = torch.is_autocast_cache_enabled() + prev_fastdtype = torch.get_autocast_gpu_dtype() + autocast_information = ( + is_autocast_enabled, + is_autocast_cache_enabled, + prev_fastdtype, + ) + # request forward fut = rpc.rpc_async( to=callee, @@ -239,6 +269,7 @@ def new_forward(*args, **kwargs): need_activation_save, is_training, is_grad_enabled, + autocast_information, ) + tensors, ) # receive result as stub diff --git a/tests/torch/nn/parallel/pipeline_parallel/test_pp.py b/tests/torch/nn/parallel/pipeline_parallel/test_pp.py index a19cf10a..4cdc0939 100644 --- a/tests/torch/nn/parallel/pipeline_parallel/test_pp.py +++ b/tests/torch/nn/parallel/pipeline_parallel/test_pp.py @@ -3,6 +3,8 @@ import torch import torch.nn as nn import torch.distributed as dist +from torch.cuda.amp import autocast +from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler from torch.distributed import rpc from torch.optim import Adam from torch.utils.data import DataLoader @@ -157,7 +159,7 @@ def forward( parallel_context = ParallelContext.from_torch( data_parallel_size=1, - pipeline_parallel_size=4, + pipeline_parallel_size=8, tensor_parallel_size=1, ) @@ -199,6 +201,10 @@ def forward( optimizer_pp = Adam(wrapper_pp.parameters(), lr=3e-2) optimizer_no_pp = Adam(model_no_pp.parameters(), lr=3e-2) +# TODO; specify group +scaler_pp = ShardedGradScaler() +scaler_no_pp = ShardedGradScaler() + allocate_params(wrapper_pp, parallel_context) # @@ -224,14 +230,17 @@ def run(): tokenizer.pad_token = tokenizer.eos_token datasets = load_dataset("squad").data["train"]["context"] - datasets = [str(sample) for sample in datasets[:5000]] + datasets = [str(sample) for sample in datasets[:50000]] dataloader = DataLoader(datasets, batch_size=batch_size) pp_losses 
= [] no_pp_losses = [] + # wrapper_pp.eval() + # model_no_pp.eval() + with torch.enable_grad(): - # with torch.no_grad(): + # with torch.no_grad(): for i, data in enumerate(dataloader): inputs = tokenizer( data, @@ -244,25 +253,33 @@ def run(): optimizer_pp.zero_grad(set_to_none=True) optimizer_no_pp.zero_grad(set_to_none=True) + def result_generator(inputs): + with autocast(): + yield from wrapper_pp(**inputs, labels=inputs["input_ids"]) + cum_loss_pp = torch.zeros(1) - for ind, out_pp in enumerate( - wrapper_pp(**inputs, labels=inputs["input_ids"]) - ): + for ind, out_pp in enumerate(result_generator(inputs)): loss_pp = out_pp.loss loss_pp = loss_pp / num_micro_batches - loss_pp.backward() + scaler_pp.scale(loss_pp).backward() cum_loss_pp += loss_pp.detach().item() - out_no_pp = model_no_pp(**inputs, labels=inputs["input_ids"]) + with autocast(): + out_no_pp = model_no_pp(**inputs, labels=inputs["input_ids"]) loss_no_pp = out_no_pp.loss - loss_no_pp.backward() + scaler_no_pp.scale(loss_no_pp).backward() if dist.get_rank() == 0: print(f"{dist.get_rank()=}, {cum_loss_pp=}, {loss_no_pp=}") - optimizer_pp.step() - optimizer_no_pp.step() + # optimizer_pp.step() + # optimizer_no_pp.step() + + scaler_pp.step(optimizer_pp) + scaler_no_pp.step(optimizer_no_pp) + scaler_pp.update() + scaler_no_pp.update() pp_losses.append(cum_loss_pp.item()) no_pp_losses.append(loss_no_pp.item()) From 8d80aa76029a1d1521d5a3728dad6a4ba6660af7 Mon Sep 17 00:00:00 2001 From: ohwi Date: Fri, 30 Sep 2022 21:38:18 +0900 Subject: [PATCH 8/9] Revert "support amp (WIP)" This reverts commit da6c087af5a3b2025bf88ff7e3469511863a6528. --- .../parallel/pipeline_parallel/_functional.py | 11 +----- .../pipeline_parallel/pipeline_parallel.py | 37 ++---------------- .../nn/parallel/pipeline_parallel/test_pp.py | 39 ++++++------------- 3 files changed, 15 insertions(+), 72 deletions(-) diff --git a/oslo/torch/nn/parallel/pipeline_parallel/_functional.py b/oslo/torch/nn/parallel/pipeline_parallel/_functional.py index 71460bf4..d67315b3 100644 --- a/oslo/torch/nn/parallel/pipeline_parallel/_functional.py +++ b/oslo/torch/nn/parallel/pipeline_parallel/_functional.py @@ -20,7 +20,6 @@ def remote_module_forward( requires_redirection, is_training, is_grad_enabled, - autocast_information, *tensors ): if requires_redirection and is_training and is_grad_enabled: @@ -34,15 +33,7 @@ def remote_module_forward( (args, kwargs), _ = unpack_tensor_stub([args_stub, kwargs_stub], tensors) forward_fn = get_original_forward_function(location) - - is_autocast_enabled, is_autocast_cache_enabled, prev_fastdtype = autocast_information - autocast = torch.cuda.amp.autocast( - enabled=is_autocast_enabled, - dtype=prev_fastdtype, - cache_enabled=is_autocast_cache_enabled, - ) - - with torch.set_grad_enabled(is_grad_enabled), autocast: + with torch.set_grad_enabled(is_grad_enabled): result = forward_fn(*args, **kwargs) result_stub, tensors = pack_tensor_stub(result, []) diff --git a/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py b/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py index 1efcce1a..f13acf92 100644 --- a/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py +++ b/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py @@ -44,14 +44,8 @@ def PipelineParallel( # function to launch self.module. needs this # function because torch.set_grad_enabled() is # thread local. 
-def launch(fn, is_grad_enabled, autocast_information, *args, **kwargs): - is_autocast_enabled, is_autocast_cache_enabled, prev_fastdtype = autocast_information - autocast = torch.cuda.amp.autocast( - enabled=is_autocast_enabled, - dtype=prev_fastdtype, - cache_enabled=is_autocast_cache_enabled, - ) - with torch.set_grad_enabled(is_grad_enabled), autocast: +def launch(fn, is_grad_enabled, *args, **kwargs): + with torch.set_grad_enabled(is_grad_enabled): result = fn(*args, **kwargs) return result @@ -144,23 +138,8 @@ def forward(self, *args, **kwargs): # ind += 1 is_grad_enabled = torch.is_grad_enabled() - is_autocast_enabled = torch.is_autocast_enabled() - is_autocast_cache_enabled = torch.is_autocast_cache_enabled() - prev_fastdtype = torch.get_autocast_gpu_dtype() - autocast_information = ( - is_autocast_enabled, - is_autocast_cache_enabled, - prev_fastdtype, - ) for ind, (args_, kwargs_) in enumerate(zip(new_args, new_kwargs)): - future = self.producer.submit( - launch, - self.module, - is_grad_enabled, - autocast_information, - *args_, - **kwargs_ - ) + future = self.producer.submit(launch, self.module, is_grad_enabled, *args_, **kwargs_) futures.append(future) for i, done in enumerate(concurrent.futures.as_completed(futures)): @@ -250,15 +229,6 @@ def new_forward(*args, **kwargs): module_device.index ) - is_autocast_enabled = torch.is_autocast_enabled() - is_autocast_cache_enabled = torch.is_autocast_cache_enabled() - prev_fastdtype = torch.get_autocast_gpu_dtype() - autocast_information = ( - is_autocast_enabled, - is_autocast_cache_enabled, - prev_fastdtype, - ) - # request forward fut = rpc.rpc_async( to=callee, @@ -269,7 +239,6 @@ def new_forward(*args, **kwargs): need_activation_save, is_training, is_grad_enabled, - autocast_information, ) + tensors, ) # receive result as stub diff --git a/tests/torch/nn/parallel/pipeline_parallel/test_pp.py b/tests/torch/nn/parallel/pipeline_parallel/test_pp.py index 4cdc0939..a19cf10a 100644 --- a/tests/torch/nn/parallel/pipeline_parallel/test_pp.py +++ b/tests/torch/nn/parallel/pipeline_parallel/test_pp.py @@ -3,8 +3,6 @@ import torch import torch.nn as nn import torch.distributed as dist -from torch.cuda.amp import autocast -from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler from torch.distributed import rpc from torch.optim import Adam from torch.utils.data import DataLoader @@ -159,7 +157,7 @@ def forward( parallel_context = ParallelContext.from_torch( data_parallel_size=1, - pipeline_parallel_size=8, + pipeline_parallel_size=4, tensor_parallel_size=1, ) @@ -201,10 +199,6 @@ def forward( optimizer_pp = Adam(wrapper_pp.parameters(), lr=3e-2) optimizer_no_pp = Adam(model_no_pp.parameters(), lr=3e-2) -# TODO; specify group -scaler_pp = ShardedGradScaler() -scaler_no_pp = ShardedGradScaler() - allocate_params(wrapper_pp, parallel_context) # @@ -230,17 +224,14 @@ def run(): tokenizer.pad_token = tokenizer.eos_token datasets = load_dataset("squad").data["train"]["context"] - datasets = [str(sample) for sample in datasets[:50000]] + datasets = [str(sample) for sample in datasets[:5000]] dataloader = DataLoader(datasets, batch_size=batch_size) pp_losses = [] no_pp_losses = [] - # wrapper_pp.eval() - # model_no_pp.eval() - with torch.enable_grad(): - # with torch.no_grad(): + # with torch.no_grad(): for i, data in enumerate(dataloader): inputs = tokenizer( data, @@ -253,33 +244,25 @@ def run(): optimizer_pp.zero_grad(set_to_none=True) optimizer_no_pp.zero_grad(set_to_none=True) - def result_generator(inputs): - with autocast(): 
- yield from wrapper_pp(**inputs, labels=inputs["input_ids"]) - cum_loss_pp = torch.zeros(1) - for ind, out_pp in enumerate(result_generator(inputs)): + for ind, out_pp in enumerate( + wrapper_pp(**inputs, labels=inputs["input_ids"]) + ): loss_pp = out_pp.loss loss_pp = loss_pp / num_micro_batches - scaler_pp.scale(loss_pp).backward() + loss_pp.backward() cum_loss_pp += loss_pp.detach().item() - with autocast(): - out_no_pp = model_no_pp(**inputs, labels=inputs["input_ids"]) + out_no_pp = model_no_pp(**inputs, labels=inputs["input_ids"]) loss_no_pp = out_no_pp.loss - scaler_no_pp.scale(loss_no_pp).backward() + loss_no_pp.backward() if dist.get_rank() == 0: print(f"{dist.get_rank()=}, {cum_loss_pp=}, {loss_no_pp=}") - # optimizer_pp.step() - # optimizer_no_pp.step() - - scaler_pp.step(optimizer_pp) - scaler_no_pp.step(optimizer_no_pp) - scaler_pp.update() - scaler_no_pp.update() + optimizer_pp.step() + optimizer_no_pp.step() pp_losses.append(cum_loss_pp.item()) no_pp_losses.append(loss_no_pp.item()) From 8ab949dae6d27342ae12a7260da6c85d08239202 Mon Sep 17 00:00:00 2001 From: ohwi Date: Fri, 30 Sep 2022 21:49:13 +0900 Subject: [PATCH 9/9] pre-commit --- .../nn/parallel/pipeline_parallel/_buffers.py | 6 ++-- .../parallel/pipeline_parallel/_functional.py | 16 +++++++--- .../parallel/pipeline_parallel/_messages.py | 14 +++++--- .../nn/parallel/pipeline_parallel/_utils.py | 4 +-- .../pipeline_parallel/pipeline_parallel.py | 32 +++++++++++++------ .../nn/parallel/pipeline_parallel/test_pp.py | 27 +++++++++++----- 6 files changed, 68 insertions(+), 31 deletions(-) diff --git a/oslo/torch/nn/parallel/pipeline_parallel/_buffers.py b/oslo/torch/nn/parallel/pipeline_parallel/_buffers.py index 3cb04574..2ce852b4 100644 --- a/oslo/torch/nn/parallel/pipeline_parallel/_buffers.py +++ b/oslo/torch/nn/parallel/pipeline_parallel/_buffers.py @@ -1,4 +1,6 @@ -from oslo.torch.nn.parallel.pipeline_parallel._sync import register_location_for_forward_counter +from oslo.torch.nn.parallel.pipeline_parallel._sync import ( + register_location_for_forward_counter, +) # original forward dictionary @@ -31,4 +33,4 @@ def save_activation(key, activation): def pop_activation(key): - return _ACTIVATIONS.pop(key, []) # TODO; okay? + return _ACTIVATIONS.pop(key, []) # TODO; okay? 
diff --git a/oslo/torch/nn/parallel/pipeline_parallel/_functional.py b/oslo/torch/nn/parallel/pipeline_parallel/_functional.py index d67315b3..48f42596 100644 --- a/oslo/torch/nn/parallel/pipeline_parallel/_functional.py +++ b/oslo/torch/nn/parallel/pipeline_parallel/_functional.py @@ -11,12 +11,18 @@ register_job_requires_backward, notify_backward_job_done, ) -from oslo.torch.nn.parallel.pipeline_parallel._messages import pack_tensor_stub, unpack_tensor_stub +from oslo.torch.nn.parallel.pipeline_parallel._messages import ( + pack_tensor_stub, + unpack_tensor_stub, +) def remote_module_forward( - caller, location, unique_key, - args_stub, kwargs_stub, + caller, + location, + unique_key, + args_stub, + kwargs_stub, requires_redirection, is_training, is_grad_enabled, @@ -37,7 +43,9 @@ def remote_module_forward( result = forward_fn(*args, **kwargs) result_stub, tensors = pack_tensor_stub(result, []) - need_activation_save = any([t.requires_grad for t in tensors]) and is_training and is_grad_enabled + need_activation_save = ( + any([t.requires_grad for t in tensors]) and is_training and is_grad_enabled + ) if need_activation_save: save_activation(unique_key, tensors) diff --git a/oslo/torch/nn/parallel/pipeline_parallel/_messages.py b/oslo/torch/nn/parallel/pipeline_parallel/_messages.py index 47625590..e721f099 100644 --- a/oslo/torch/nn/parallel/pipeline_parallel/_messages.py +++ b/oslo/torch/nn/parallel/pipeline_parallel/_messages.py @@ -3,7 +3,9 @@ import torch from oslo.torch.nn.parallel.pipeline_parallel._utils import ( - _is_namedtuple, _is_private, _is_primitive + _is_namedtuple, + _is_private, + _is_primitive, ) @@ -30,7 +32,7 @@ def pack_tensor_stub(obj, args_list): for i in range(len(obj_list)): obj_list_i, args_list = pack_tensor_stub(obj_list[i], args_list) obj_list_i[i] = obj_list_i - obj = obj.__class__._make(obj_list) # use namedtuple's method + obj = obj.__class__._make(obj_list) # use namedtuple's method return obj, args_list @@ -60,9 +62,10 @@ def pack_tensor_stub(obj, args_list): elif _is_primitive(obj): return obj, args_list - else: # other kinds of object + else: # other kinds of object members = [ - attr for attr in dir(obj) + attr + for attr in dir(obj) if not callable(getattr(obj, attr)) and not _is_private(attr) ] for m in members: @@ -120,7 +123,8 @@ def unpack_tensor_stub(obj, args_list): else: # other kinds of object members = [ - attr for attr in dir(obj) + attr + for attr in dir(obj) if not callable(getattr(obj, attr)) and not _is_private(attr) ] for m in members: diff --git a/oslo/torch/nn/parallel/pipeline_parallel/_utils.py b/oslo/torch/nn/parallel/pipeline_parallel/_utils.py index 0869e2b1..61cc17ba 100644 --- a/oslo/torch/nn/parallel/pipeline_parallel/_utils.py +++ b/oslo/torch/nn/parallel/pipeline_parallel/_utils.py @@ -36,8 +36,8 @@ def _is_namedtuple(obj): def _is_primitive(obj): - return not hasattr(obj, '__dict__') + return not hasattr(obj, "__dict__") def _is_private(attr): - return attr.startswith('__') + return attr.startswith("__") diff --git a/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py b/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py index f13acf92..389e77fe 100644 --- a/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py +++ b/oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py @@ -14,7 +14,10 @@ get_module_device_location, save_activation, ) -from oslo.torch.nn.parallel.pipeline_parallel._functional import remote_module_forward, apply_backward_redirection +from 
oslo.torch.nn.parallel.pipeline_parallel._functional import ( + remote_module_forward, + apply_backward_redirection, +) from oslo.torch.nn.parallel.pipeline_parallel._sync import ( wait_other_ranks, make_unique_key, @@ -22,7 +25,10 @@ set_result, get_result, ) -from oslo.torch.nn.parallel.pipeline_parallel._messages import pack_tensor_stub, unpack_tensor_stub +from oslo.torch.nn.parallel.pipeline_parallel._messages import ( + pack_tensor_stub, + unpack_tensor_stub, +) from oslo.torch.nn.parallel.pipeline_parallel._model_partitioner import ModelPartitioner @@ -37,7 +43,7 @@ def PipelineParallel( module=module, parallel_context=parallel_context, memory_computation_balance=memory_computation_balance, - num_micro_batches=num_micro_batches + num_micro_batches=num_micro_batches, ) @@ -139,7 +145,9 @@ def forward(self, *args, **kwargs): is_grad_enabled = torch.is_grad_enabled() for ind, (args_, kwargs_) in enumerate(zip(new_args, new_kwargs)): - future = self.producer.submit(launch, self.module, is_grad_enabled, *args_, **kwargs_) + future = self.producer.submit( + launch, self.module, is_grad_enabled, *args_, **kwargs_ + ) futures.append(future) for i, done in enumerate(concurrent.futures.as_completed(futures)): @@ -234,12 +242,16 @@ def new_forward(*args, **kwargs): to=callee, func=remote_module_forward, args=( - caller, location, unique_key, - args_stub, kwargs_stub, - need_activation_save, - is_training, - is_grad_enabled, - ) + tensors, + caller, + location, + unique_key, + args_stub, + kwargs_stub, + need_activation_save, + is_training, + is_grad_enabled, + ) + + tensors, ) # receive result as stub result_stub, tensors, requires_redirection = fut.wait() diff --git a/tests/torch/nn/parallel/pipeline_parallel/test_pp.py b/tests/torch/nn/parallel/pipeline_parallel/test_pp.py index a19cf10a..bbf46d58 100644 --- a/tests/torch/nn/parallel/pipeline_parallel/test_pp.py +++ b/tests/torch/nn/parallel/pipeline_parallel/test_pp.py @@ -15,9 +15,12 @@ from datasets import load_dataset from transformers import ( AutoTokenizer, - GPT2Config, GPT2LMHeadModel, - T5Config, T5ForConditionalGeneration, - BartConfig, BartForConditionalGeneration, + GPT2Config, + GPT2LMHeadModel, + T5Config, + T5ForConditionalGeneration, + BartConfig, + BartForConditionalGeneration, set_seed, ) @@ -56,7 +59,9 @@ def forward( ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) # Encode if needed (training, first prediction pass) if encoder_outputs is None: @@ -82,7 +87,11 @@ def forward( if self.model_parallel: torch.cuda.set_device(self.decoder.first_device) - if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + if ( + labels is not None + and decoder_input_ids is None + and decoder_inputs_embeds is None + ): # get decoder inputs from shifting lm labels to the right decoder_input_ids = self._shift_right(labels) @@ -95,7 +104,9 @@ def forward( if attention_mask is not None: attention_mask = attention_mask.to(self.decoder.first_device) if decoder_attention_mask is not None: - decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) + decoder_attention_mask = decoder_attention_mask.to( + self.decoder.first_device + ) # Decode decoder_outputs = self.decoder( @@ -170,7 +181,7 @@ def forward( model_name = "t5-small" config 
= T5Config.from_pretrained(model_name) -config.dropout_rate = 0. +config.dropout_rate = 0.0 model = T5ForConditionalGeneration(config) # model = T5Debug(config) @@ -215,7 +226,7 @@ def forward( # if torch.distributed.get_rank() == 1: for k, v in _MODULE_DEVICE_LOCATIONS.items(): - print(f'{k}: {v}') + print(f"{k}: {v}") def run():