
Commit

pre-commit
ohwi committed Sep 30, 2022
1 parent 8d80aa7, commit 8ab949d
Showing 6 changed files with 68 additions and 31 deletions.
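All six files below change formatting only, consistent with running the black formatter through pre-commit: long imports are wrapped in parentheses with trailing commas, single quotes become double quotes, 0. becomes 0.0, and inline comments get two leading spaces. For orientation only, a minimal pre-commit configuration that would drive such a pass is sketched here; the repository's actual .pre-commit-config.yaml is not part of this commit, and the pinned rev is an assumed example.

    repos:
      - repo: https://github.com/psf/black   # official black pre-commit hook
        rev: 22.8.0                           # assumed example pin, not taken from this repository
        hooks:
          - id: black

Typical usage (illustrative): "pre-commit install" registers the git hook, and "pre-commit run --all-files" reformats the whole tree in one pass, which is the kind of sweep a commit like this usually reflects.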
oslo/torch/nn/parallel/pipeline_parallel/_buffers.py (6 changes: 4 additions & 2 deletions)
@@ -1,4 +1,6 @@
from oslo.torch.nn.parallel.pipeline_parallel._sync import register_location_for_forward_counter
from oslo.torch.nn.parallel.pipeline_parallel._sync import (
register_location_for_forward_counter,
)


# original forward dictionary
@@ -31,4 +33,4 @@ def save_activation(key, activation):


def pop_activation(key):
return _ACTIVATIONS.pop(key, []) # TODO; okay?
return _ACTIVATIONS.pop(key, []) # TODO; okay?
oslo/torch/nn/parallel/pipeline_parallel/_functional.py (16 changes: 12 additions & 4 deletions)
@@ -11,12 +11,18 @@
register_job_requires_backward,
notify_backward_job_done,
)
from oslo.torch.nn.parallel.pipeline_parallel._messages import pack_tensor_stub, unpack_tensor_stub
from oslo.torch.nn.parallel.pipeline_parallel._messages import (
pack_tensor_stub,
unpack_tensor_stub,
)


def remote_module_forward(
caller, location, unique_key,
args_stub, kwargs_stub,
caller,
location,
unique_key,
args_stub,
kwargs_stub,
requires_redirection,
is_training,
is_grad_enabled,
@@ -37,7 +43,9 @@ def remote_module_forward(
result = forward_fn(*args, **kwargs)

result_stub, tensors = pack_tensor_stub(result, [])
need_activation_save = any([t.requires_grad for t in tensors]) and is_training and is_grad_enabled
need_activation_save = (
any([t.requires_grad for t in tensors]) and is_training and is_grad_enabled
)
if need_activation_save:
save_activation(unique_key, tensors)

oslo/torch/nn/parallel/pipeline_parallel/_messages.py (14 changes: 9 additions & 5 deletions)
@@ -3,7 +3,9 @@
import torch

from oslo.torch.nn.parallel.pipeline_parallel._utils import (
_is_namedtuple, _is_private, _is_primitive
_is_namedtuple,
_is_private,
_is_primitive,
)


@@ -30,7 +32,7 @@ def pack_tensor_stub(obj, args_list):
for i in range(len(obj_list)):
obj_list_i, args_list = pack_tensor_stub(obj_list[i], args_list)
obj_list_i[i] = obj_list_i
obj = obj.__class__._make(obj_list) # use namedtuple's method
obj = obj.__class__._make(obj_list) # use namedtuple's method

return obj, args_list

@@ -60,9 +62,10 @@ def pack_tensor_stub(obj, args_list):
elif _is_primitive(obj):
return obj, args_list

else: # other kinds of object
else: # other kinds of object
members = [
attr for attr in dir(obj)
attr
for attr in dir(obj)
if not callable(getattr(obj, attr)) and not _is_private(attr)
]
for m in members:
@@ -120,7 +123,8 @@ def unpack_tensor_stub(obj, args_list):

else: # other kinds of object
members = [
attr for attr in dir(obj)
attr
for attr in dir(obj)
if not callable(getattr(obj, attr)) and not _is_private(attr)
]
for m in members:
oslo/torch/nn/parallel/pipeline_parallel/_utils.py (4 changes: 2 additions & 2 deletions)
@@ -36,8 +36,8 @@ def _is_namedtuple(obj):


def _is_primitive(obj):
return not hasattr(obj, '__dict__')
return not hasattr(obj, "__dict__")


def _is_private(attr):
return attr.startswith('__')
return attr.startswith("__")
oslo/torch/nn/parallel/pipeline_parallel/pipeline_parallel.py (32 changes: 22 additions & 10 deletions)
@@ -14,15 +14,21 @@
get_module_device_location,
save_activation,
)
from oslo.torch.nn.parallel.pipeline_parallel._functional import remote_module_forward, apply_backward_redirection
from oslo.torch.nn.parallel.pipeline_parallel._functional import (
remote_module_forward,
apply_backward_redirection,
)
from oslo.torch.nn.parallel.pipeline_parallel._sync import (
wait_other_ranks,
make_unique_key,
reset_forward_used_counter,
set_result,
get_result,
)
from oslo.torch.nn.parallel.pipeline_parallel._messages import pack_tensor_stub, unpack_tensor_stub
from oslo.torch.nn.parallel.pipeline_parallel._messages import (
pack_tensor_stub,
unpack_tensor_stub,
)
from oslo.torch.nn.parallel.pipeline_parallel._model_partitioner import ModelPartitioner


@@ -37,7 +43,7 @@ def PipelineParallel(
module=module,
parallel_context=parallel_context,
memory_computation_balance=memory_computation_balance,
num_micro_batches=num_micro_batches
num_micro_batches=num_micro_batches,
)


@@ -139,7 +145,9 @@ def forward(self, *args, **kwargs):

is_grad_enabled = torch.is_grad_enabled()
for ind, (args_, kwargs_) in enumerate(zip(new_args, new_kwargs)):
future = self.producer.submit(launch, self.module, is_grad_enabled, *args_, **kwargs_)
future = self.producer.submit(
launch, self.module, is_grad_enabled, *args_, **kwargs_
)
futures.append(future)

for i, done in enumerate(concurrent.futures.as_completed(futures)):
@@ -234,12 +242,16 @@ def new_forward(*args, **kwargs):
to=callee,
func=remote_module_forward,
args=(
caller, location, unique_key,
args_stub, kwargs_stub,
need_activation_save,
is_training,
is_grad_enabled,
) + tensors,
caller,
location,
unique_key,
args_stub,
kwargs_stub,
need_activation_save,
is_training,
is_grad_enabled,
)
+ tensors,
)
# receive result as stub
result_stub, tensors, requires_redirection = fut.wait()
tests/torch/nn/parallel/pipeline_parallel/test_pp.py (27 changes: 19 additions & 8 deletions)
@@ -15,9 +15,12 @@
from datasets import load_dataset
from transformers import (
AutoTokenizer,
GPT2Config, GPT2LMHeadModel,
T5Config, T5ForConditionalGeneration,
BartConfig, BartForConditionalGeneration,
GPT2Config,
GPT2LMHeadModel,
T5Config,
T5ForConditionalGeneration,
BartConfig,
BartForConditionalGeneration,
set_seed,
)

@@ -56,7 +59,9 @@ def forward(
) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:

use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)

# Encode if needed (training, first prediction pass)
if encoder_outputs is None:
@@ -82,7 +87,11 @@
if self.model_parallel:
torch.cuda.set_device(self.decoder.first_device)

if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
if (
labels is not None
and decoder_input_ids is None
and decoder_inputs_embeds is None
):
# get decoder inputs from shifting lm labels to the right
decoder_input_ids = self._shift_right(labels)

@@ -95,7 +104,9 @@
if attention_mask is not None:
attention_mask = attention_mask.to(self.decoder.first_device)
if decoder_attention_mask is not None:
decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
decoder_attention_mask = decoder_attention_mask.to(
self.decoder.first_device
)

# Decode
decoder_outputs = self.decoder(
@@ -170,7 +181,7 @@ def forward(

model_name = "t5-small"
config = T5Config.from_pretrained(model_name)
config.dropout_rate = 0.
config.dropout_rate = 0.0
model = T5ForConditionalGeneration(config)
# model = T5Debug(config)

@@ -215,7 +226,7 @@ def forward(
#
if torch.distributed.get_rank() == 1:
for k, v in _MODULE_DEVICE_LOCATIONS.items():
print(f'{k}: {v}')
print(f"{k}: {v}")


def run():
