CodeBlockLogic as default logic #618

Status: Open. Wants to merge 4 commits into master.
10 changes: 5 additions & 5 deletions pytorch_pfn_extras/engine.py
@@ -72,7 +72,7 @@ def create_trainer(
Device name used for selecting a corresponding runtime class.
logic:
A logic object. If `None` is given, a logic object is instantiated
from the default logic class.
from the :class:`pytorch_pfn_extras.handler.CodeBlockLogic` class.
transform_model:
A function to transform a model structure, often used to unwrap a
module from a DDP module.
@@ -96,7 +96,7 @@ def create_trainer(
runtime_options = dict(
runtime_options if runtime_options
else options.pop('runtime', {}))
logic = handler_module.Logic() if logic is None else logic
logic = handler_module.CodeBlockLogic() if logic is None else logic
handler_class = handler_class if handler_class else handler_module.Handler

entry_runtime_cls = runtime_registry.get_runtime_class_for_device_spec(
@@ -128,7 +128,7 @@ def create_evaluator(
progress_bar: bool = False,
device: 'DeviceLike' = 'cpu',
metrics: Optional[Sequence['MetricType']] = None,
logic: Optional[handler_module.Logic] = None,
logic: Optional[handler_module.BaseLogic] = None,
handler_class: Optional[Type[handler_module.Handler]] = None,
options: Optional[Dict[str, Any]] = None,
runtime_options: Optional[Mapping[str, Any]] = None,
@@ -151,7 +151,7 @@ def create_evaluator(
output for the reporting.
logic:
A logic object. If `None` is given, a logic object is instantiated
from the default logic class.
from the :class:`pytorch_pfn_extras.handler.CodeBlockLogic` class.
handler_class:
A handler class that instantiates a handler object. If `None` is
given, `ppe.handler.Handler` is used as a default handler class.
@@ -173,7 +173,7 @@ def create_evaluator(
runtime_options = dict(
runtime_options if runtime_options
else options.pop('runtime', {}))
logic = handler_module.Logic() if logic is None else logic
logic = handler_module.CodeBlockLogic() if logic is None else logic
handler_class = handler_class if handler_class else handler_module.Handler

entry_runtime_cls = runtime_registry.get_runtime_class_for_device_spec(
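This is the core behavioral change: when `logic` is left as `None`, `create_trainer` and `create_evaluator` now instantiate `CodeBlockLogic` instead of `Logic`. A minimal sketch of what that means for callers, assuming the public `ppe.engine.create_trainer` and `ppe.handler.Logic` entry points (the sketch itself is not part of this diff):

import torch
import pytorch_pfn_extras as ppe

model = torch.nn.Linear(20, 10)
ppe.to(model, 'cpu')
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# logic=None now resolves to handler_module.CodeBlockLogic()
trainer = ppe.engine.create_trainer(model, optimizer, max_epochs=1, device='cpu')

# The previous default can still be requested explicitly
trainer_with_old_logic = ppe.engine.create_trainer(
    model, optimizer, max_epochs=1, device='cpu', logic=ppe.handler.Logic())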
34 changes: 32 additions & 2 deletions pytorch_pfn_extras/handler/_code_block.py
@@ -2,6 +2,7 @@
from typing import Any, Callable, Dict, List, Optional, Set

import torch
import pytorch_pfn_extras as ppe


@dataclass
@@ -26,6 +27,7 @@ class CodeBlock:
backprop: bool
backprop_from: Optional[str]
backprop_to: Optional[Set[str]]
backprop_fn: Optional[Callable[..., Any]]
state: Dict[str, Any]
runtime: Any

@@ -56,6 +58,7 @@ def update_parameters(
optimizers: List[torch.optim.Optimizer],
backprop_from: Optional[str] = None,
backprop_to: Optional[Set[str]] = None,
backprop_fn: Optional[Callable[..., Any]] = None,
) -> CodeBlock:
"""
Returns a ``CodeBlock`` that performs the forward, backward passes and
@@ -80,6 +83,7 @@ def update_parameters(
backprop=True,
backprop_from=backprop_from,
backprop_to=backprop_to,
backprop_fn=backprop_fn,
state=codeblock.state,
runtime=codeblock.runtime,
)
@@ -106,9 +110,34 @@ def forward(block: Callable) -> CodeBlock:
else:
module = getattr(block, '__self__', None)
assert module is not None
func = block
runtime = ppe.runtime._runtime._module_runtime_tag(module)

def _forward(batch: Any) -> Any:

def _normalize_outputs(outputs: Any) -> Dict[str, Any]:
target: Dict[str, Any]
if isinstance(outputs, tuple) and hasattr(outputs, '_fields'):
# namedtuple
target = outputs._asdict() # type: ignore[attr-defined]
elif isinstance(outputs, dict):
target = outputs
elif isinstance(outputs, (list, tuple)):
target = {str(i): out for i, out in enumerate(outputs)}
else:
target = {"0": outputs}
return target

if isinstance(batch, tuple) and hasattr(batch, '_fields'):
# namedtuple
return _normalize_outputs(block(batch))
if isinstance(batch, dict):
return _normalize_outputs(block(**batch))
if isinstance(batch, (list, tuple)):
return _normalize_outputs(block(*batch))
return _normalize_outputs(block(batch))

func = _forward
state = {}
runtime = getattr(module, '_ppe_runtime', None)
assert runtime is not None

return CodeBlock(
Expand All @@ -117,6 +146,7 @@ def forward(block: Callable) -> CodeBlock:
backprop=False,
backprop_from=None,
backprop_to=None,
backprop_fn=None,
state=state,
runtime=runtime,
)
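The new `_forward` wrapper both unpacks the incoming batch (namedtuple, dict, list/tuple, or a single value) before calling the block and normalizes whatever the block returns into a dict. The standalone sketch below reproduces only the output-normalization rule; `_normalize_example` is a hypothetical helper for illustration, not part of the library:

from collections import namedtuple
from typing import Any, Dict

def _normalize_example(outputs: Any) -> Dict[str, Any]:
    if isinstance(outputs, tuple) and hasattr(outputs, '_fields'):
        return outputs._asdict()            # namedtuple -> dict of its fields
    if isinstance(outputs, dict):
        return outputs                       # dicts pass through unchanged
    if isinstance(outputs, (list, tuple)):
        return {str(i): o for i, o in enumerate(outputs)}  # positional keys
    return {"0": outputs}                    # a single value gets key "0"

Out = namedtuple('Out', ['loss', 'acc'])
assert _normalize_example(Out(0.5, 0.9)) == {'loss': 0.5, 'acc': 0.9}
assert _normalize_example([1, 2]) == {'0': 1, '1': 2}
assert _normalize_example(3.0) == {'0': 3.0}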
2 changes: 2 additions & 0 deletions pytorch_pfn_extras/handler/_logic.py
@@ -383,6 +383,7 @@ def consume_options(self, options: Dict[str, Any]) -> None:
self.backward_outputs = options.pop('backward_outputs', None)
if self.backward_outputs is not None:
assert isinstance(self.backward_outputs, str)
self._backward_fn = options.pop('backward_function', None)

def train_epoch_begin(
self,
@@ -433,6 +434,7 @@ def train_step(
list(optimizers.values()),
self.backward_outputs,
None,
self._backward_fn,
)(batch)

def train_validation_begin(
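`consume_options` now also picks up a `backward_function` callable, which `train_step` forwards to `update_parameters` as `backprop_fn`. A hedged sketch of how a caller might supply it, assuming the `options` dict given to `create_trainer` reaches `consume_options` the same way existing options such as `backward_outputs` do:

import torch
import pytorch_pfn_extras as ppe

model = torch.nn.Linear(20, 10)
ppe.to(model, 'cpu')
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

def custom_backward(loss: torch.Tensor) -> None:
    # Illustration only: this callable replaces the plain loss.backward() call
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

trainer = ppe.engine.create_trainer(
    model, optimizer, max_epochs=1, device='cpu',
    options={'backward_function': custom_backward})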
12 changes: 9 additions & 3 deletions pytorch_pfn_extras/runtime/_runtime.py
@@ -432,9 +432,10 @@ def _scale(x: torch.Tensor) -> torch.Tensor:

# with autocast
with _autocast(enabled=self._autocast):
out = code_block.func(**batch)
out = code_block.func(batch)

# CodeBlocks already return dicts, so no normalization is needed here
to_backprop = []
if code_block.backprop:
if code_block.backprop_from is None:
for v in out.values():
@@ -447,10 +448,15 @@ def _scale(x: torch.Tensor) -> torch.Tensor:
or v.dtype.is_complex
)
):
_scale(v).backward() # type: ignore[no-untyped-call]
to_backprop.append(_scale(v))
else:
_scale(out[code_block.backprop_from]).backward() # type: ignore
to_backprop.append(_scale(out[code_block.backprop_from]))

for v in to_backprop:
if code_block.backprop_fn is not None:
code_block.backprop_fn(v) # type: ignore
else:
v.backward() # type: ignore[no-untyped-call]
if len(code_block.optimizers) == 0:
return out

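The runtime change defers the actual backward call: losses are first collected into `to_backprop`, and each one is then handed to `code_block.backprop_fn` when it is set, or to the usual `.backward()` otherwise. A standalone mimic of that dispatch (the `run_backward` helper is illustrative, not library code):

from typing import Callable, List, Optional
import torch

def run_backward(
    losses: List[torch.Tensor],
    backprop_fn: Optional[Callable[[torch.Tensor], None]] = None,
) -> None:
    for v in losses:
        if backprop_fn is not None:
            backprop_fn(v)     # user-controlled backward (e.g. scaled or clipped)
        else:
            v.backward()       # default behavior, unchanged from before

x = torch.ones(3, requires_grad=True)
run_backward([(x * 2).sum()])                                # default path
run_backward([(x * 3).sum()], lambda loss: loss.backward())  # custom path
print(x.grad)  # tensor([5., 5., 5.]) after both passes accumulate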
2 changes: 1 addition & 1 deletion tests/pytorch_pfn_extras_tests/test_logic.py
@@ -46,8 +46,8 @@ def test_trainer(device):
iters_per_epoch = 10
epochs = 20
model = MyModel()
ppe.to(model, device)
model_with_loss = MyModelWithLossFn(model)
ppe.to(model_with_loss, device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = torch.utils.data.DataLoader(
[(torch.rand(20,), torch.rand(10,)) for i in range(iters_per_epoch)])
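The test updates here and in test_trainer.py below all move the `ppe.to` registration from the inner model to the loss-wrapping module, which matches the runtime-tag lookup added in `_code_block.py`: the module whose `forward` is wrapped into a CodeBlock must carry the runtime tag. A minimal runnable sketch of the pattern with stand-in classes (the names are illustrative, not the test fixtures):

import torch
import pytorch_pfn_extras as ppe

class Wrapped(torch.nn.Module):  # stand-in for MyModelWithLossFn
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, t):
        return {'loss': torch.nn.functional.mse_loss(self.model(x), t)}

inner = torch.nn.Linear(20, 10)   # stand-in for MyModel
wrapped = Wrapped(inner)
ppe.to(wrapped, 'cpu')            # was: ppe.to(inner, 'cpu')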
20 changes: 10 additions & 10 deletions tests/pytorch_pfn_extras_tests/training_tests/test_trainer.py
@@ -74,8 +74,8 @@ def test_trainer(device, path):
if not torch.cuda.is_available() and device == 'cuda':
pytest.skip()
model = MyModel()
ppe.to(model, device)
model_with_loss = MyModelWithLossFn(model)
ppe.to(model_with_loss, device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = torch.utils.data.DataLoader(
[(torch.rand(20,), torch.rand(10,)) for i in range(10)])
@@ -109,8 +109,8 @@ def test_trainer_no_to(path):
def test_trainer_invalid_options(path):
device = 'cpu'
model = MyModel()
ppe.to(model, device)
model_with_loss = MyModelWithLossFn(model)
ppe.to(model_with_loss, device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
extensions = _make_extensions()
options = {'UNKNOWN_OPTIONS': True}
@@ -129,8 +129,8 @@ def test_train_with_evaluator(device, progress_bar, path):
if not torch.cuda.is_available() and device == 'cuda':
pytest.skip()
model = MyModel()
ppe.to(model, device)
model_with_loss = MyModelWithLossFn(model)
ppe.to(model_with_loss, device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = torch.utils.data.DataLoader(
[(torch.rand(20,), torch.rand(10,)) for i in range(10)])
@@ -158,8 +158,8 @@ def test_evaluator_trigger(evaluator_trigger, path):
device = 'cpu'
progress_bar = False
model = MyModel()
ppe.to(model, device)
model_with_loss = MyModelWithLossFn(model)
ppe.to(model_with_loss, device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = torch.utils.data.DataLoader(
[(torch.rand(20,), torch.rand(10,)) for i in range(10)])
@@ -183,8 +183,8 @@ def test_evaluator_dict(path):
device = 'cpu'
progress_bar = False
model = MyModel()
ppe.to(model, device)
model_with_loss = MyModelWithLossFn(model)
ppe.to(model_with_loss, device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = torch.utils.data.DataLoader(
[(torch.rand(20,), torch.rand(10,)) for i in range(10)])
@@ -220,8 +220,8 @@ def test_train_result_equal(device, path):

def get_result_from_trainer():
model = MyModel()
ppe.to(model, device)
model_with_loss = MyModelWithLossFn(model)
ppe.to(model_with_loss, device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
extensions = _make_extensions()

@@ -238,8 +238,8 @@ def get_result_from_trainer():

def get_result_from_training_loop():
model = MyModel()
ppe.to(model, device)
model_with_loss = MyModelWithLossFn(model)
ppe.to(model_with_loss, device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

model_with_loss.train()
@@ -293,8 +293,8 @@ def _compare_states(s1, s2):
class TestTrainerState:
def _get_trainer(self, epochs, out_dir):
model = MyModel()
ppe.to(model, 'cpu')
model_with_loss = MyModelWithLossFn(model)
ppe.to(model_with_loss, 'cpu')
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
extensions = _make_extensions()
trainer = engine.create_trainer(
@@ -356,8 +356,8 @@ def test_trainer_dict_input(device, progress_bar, path):
if not torch.cuda.is_available() and device == 'cuda':
pytest.skip()
model = MyModel()
ppe.to(model, device)
model_with_loss = MyModelWithLossDictOutput(model)
ppe.to(model_with_loss, device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = torch.utils.data.DataLoader(
[{'x': torch.rand(20,), 't': torch.rand(10,)} for i in range(10)])
@@ -405,8 +405,8 @@ def test_trainer_namedtuple_input(device, progress_bar, path):
if not torch.cuda.is_available() and device == 'cuda':
pytest.skip()
model = MyModel()
ppe.to(model, device)
model_with_loss = ModelNamedTupleIO(model)
ppe.to(model_with_loss, device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = torch.utils.data.DataLoader(
[Input(torch.rand(20,), torch.rand(10,), str(i)) for i in range(10)])