Merge pull request #52 from graphcore-research/awf/pt23
Update to PyTorch 2.2
awf committed Apr 22, 2024
2 parents d60dbc3 + 600fe18 commit 2394ee4
Showing 7 changed files with 98 additions and 96 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
@@ -4,4 +4,4 @@ einops
numpy
seaborn
tabulate
torch==2.1
torch>=2.2
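
The dependency pin loosens from an exact torch==2.1 match to a lower bound. As a minimal illustrative sketch (not part of this commit), one way to surface an out-of-date install early is a version guard at import time; using the packaging helper here is an assumption, not something the library does:

import torch
from packaging.version import Version

# Strip any local build suffix such as "+cu121" before comparing.
installed = Version(torch.__version__.split("+")[0])
assert installed >= Version("2.2"), f"torch>=2.2 expected, found {torch.__version__}"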
5 changes: 3 additions & 2 deletions unit_scaling/analysis.py
@@ -282,8 +282,6 @@ def _rename(s: str) -> str:
s = s.replace("transformer_", "")
return s

p.set_yticklabels([_rename(item.get_text()) for item in p.get_yticklabels()])

plt.axvline(2**-14, color="grey", dashes=(3, 1))
plt.axvline(2**-7, color="grey", dashes=(1, 3))
plt.axvline(240, color="grey", dashes=(1, 3))
@@ -438,6 +436,9 @@ def draw_arrow(node_a: Node, node_b: Node, direction: str) -> None:
for direction in ["fwd", "bwd"]:
draw_error_bar(n, direction)

p.set_yticks(p.get_yticks())
p.set_yticklabels([_rename(item.get_text()) for item in p.get_yticklabels()])

return p # type: ignore[no-any-return]


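The analysis.py change moves the y-tick relabelling until after all error bars have been drawn, and pins the tick positions first with set_yticks(get_yticks()); recent Matplotlib releases warn when labels are set on an axis whose tick locator is not fixed. A minimal sketch of the pattern, not part of this commit, with made-up seaborn data:

import matplotlib.pyplot as plt
import seaborn as sns

# Hypothetical per-layer scales, drawn as horizontal bars on a log x-axis.
ax = sns.barplot(x=[1e-4, 2e-3, 5e-2], y=["transformer_emb", "transformer_attn", "transformer_mlp"])
ax.set_xscale("log")

# Pin the current tick locations, then relabel them (the pattern added above).
ax.set_yticks(ax.get_yticks())
ax.set_yticklabels([t.get_text().replace("transformer_", "") for t in ax.get_yticklabels()])
plt.savefig("scales.png")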
19 changes: 10 additions & 9 deletions unit_scaling/docs.py
@@ -41,20 +41,21 @@ def _validate(
raise ValueError(f"unsupported arg '{arg}' has no default value")

@wraps(f)
def f_new(*args: Any, **kwargs: Any) -> T:
def _validate_args_supported(*args: Any, **kwargs: Any) -> T:
arg_values = dict(zip(argspec.args, args))
full_kwargs = {**arg_values, **kwargs}
for arg_name, arg_value in full_kwargs.items():
arg_default_value = default_kwargs[arg_name]
if arg_name in unsupported_args and arg_value != arg_default_value:
raise ValueError(
f"Support for the '{arg_name}' argument has not been implemented"
" for the unit-scaling library. Please remove it or replace it"
" with its default value."
)
if arg_name in unsupported_args:
arg_default_value = default_kwargs[arg_name]
if arg_value != arg_default_value:
raise ValueError(
f"Support for the '{arg_name}' argument has not been"
" implemented for the unit-scaling library."
" Please remove it or replace it with its default value."
)
return f(*args, **kwargs)

return f_new
return _validate_args_supported


def _get_docstring_from_target(
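In docs.py the wrapper gains a descriptive name, and a default value is now looked up only for arguments that are actually listed as unsupported, so a supported argument without a recorded default no longer trips the check. A condensed sketch of the behaviour, not part of this commit; the decorator and embedding function below are hypothetical stand-ins:

from functools import wraps
from inspect import getfullargspec
from typing import Any, Callable, TypeVar

T = TypeVar("T")

def reject_unsupported(f: Callable[..., T], unsupported_args=("sparse",)) -> Callable[..., T]:
    spec = getfullargspec(f)
    defaults = dict(zip(spec.args[len(spec.args) - len(spec.defaults or ()):], spec.defaults or ()))

    @wraps(f)
    def _validate_args_supported(*args: Any, **kwargs: Any) -> T:
        passed = {**dict(zip(spec.args, args)), **kwargs}
        for name, value in passed.items():
            # Only unsupported args need a default lookup, mirroring the change above.
            if name in unsupported_args and value != defaults[name]:
                raise ValueError(f"'{name}' is not supported; remove it or keep its default value")
        return f(*args, **kwargs)

    return _validate_args_supported

def embedding(x, sparse=False):  # hypothetical target function
    return x

embedding = reject_unsupported(embedding)
embedding([1, 2, 3])                 # fine: 'sparse' left at its default
# embedding([1, 2, 3], sparse=True)  # would raise ValueError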
8 changes: 4 additions & 4 deletions unit_scaling/tests/test_analysis.py
@@ -82,14 +82,14 @@ def forward(self, x: Tensor) -> Tensor: # pragma: no cover
"layer": [
"x",
"x",
"relu",
"relu",
"y",
"y",
"linear_weight",
"linear_weight",
"linear_bias",
"linear_bias",
"linear",
"linear",
"z",
"z",
"sum_1",
"sum_1",
],
18 changes: 11 additions & 7 deletions unit_scaling/tests/transforms/test_track_scales.py
@@ -19,18 +19,18 @@
)


def get_target(node: Node) -> Union[str, Callable]: # type: ignore[type-arg]
def get_target_or_node_name(node: Node) -> Union[str, Callable[..., Any]]:
return node.meta["clean_name"] if isinstance(node.target, str) else node.target


def get_targets(graph: Graph) -> Set[Union[str, Callable]]: # type: ignore[type-arg]
return set(get_target(node) for node in graph.nodes)
return set(get_target_or_node_name(node) for node in graph.nodes)


def get_target_map(
graph: Graph,
) -> Dict[Union[str, Callable], Dict[str, Any]]: # type: ignore[type-arg]
return {get_target(node): node.meta for node in graph.nodes}
return {get_target_or_node_name(node): node.meta for node in graph.nodes}


def test_track_scales() -> None:
@@ -181,12 +181,16 @@ def forward(self, idxs: Tensor) -> Tuple[Tensor, Tensor]: # pragma: no cover
model(idxs)

graph = model.scales_graph()

# Version-dependent, see https://github.com/graphcore-research/unit-scaling/pull/52
var_lhs_flatten = "x"
var_lhs_view = "x_1"
expected_targets = {
"idxs",
"emb_weight",
F.embedding,
"flatten",
"view",
var_lhs_flatten,
var_lhs_view,
"linear_weight",
"linear_bias",
F.linear,
@@ -203,7 +207,7 @@ def forward(self, idxs: Tensor) -> Tuple[Tensor, Tensor]: # pragma: no cover

graph = prune_same_scale_tensors(graph)
graph_targets = get_targets(graph)
expected_targets -= {"flatten", "view"}
expected_targets -= {var_lhs_flatten, var_lhs_view}
assert graph_targets == expected_targets

graph = prune_same_scale_tensors(graph, rtol=2**-4)
@@ -234,7 +238,7 @@ def forward(self, a: Tensor) -> Tensor: # pragma: no cover
operator.mul,
F.relu,
operator.sub,
"sum_1",
"f",
"output",
}
graph_targets = get_targets(graph)
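The node names that the graph capture assigns to intermediate values (here the results of the flatten and view tensor methods, and the final sum) changed between PyTorch releases, so the test now keeps the expected names in variables and the helper is renamed to make clear that it returns either a node name or a call target. To see what names the installed PyTorch produces, one can simply dump a captured FX graph; a minimal sketch, not part of this commit, using a hypothetical toy module and symbolic_trace for brevity:

import torch
import torch.fx

class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.flatten(end_dim=-2).view(-1).sum()

gm = torch.fx.symbolic_trace(Toy())
for node in gm.graph.nodes:
    # op (placeholder / call_method / output), the generated node name, and its target
    print(node.op, node.name, node.target)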
13 changes: 7 additions & 6 deletions unit_scaling/tests/transforms/test_unit_scale.py
@@ -2,6 +2,7 @@

import logging
import math
import re
from typing import Tuple

import torch
@@ -94,14 +95,14 @@ def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]: # pragma: no cover
loss.backward()

expected_logs = [
"unit scaling function: add\n",
"unit scaling function: iadd\n",
"unit scaling function: iadd_1 (residual-add, tau=0.5)",
"unit scaling function: add_1 (residual-add, tau=0.5)",
r"unit scaling function: (input_2)\n",
r"unit scaling function: (input_4)\n",
r"unit scaling function: (skip_1|input_3) \(residual-add, tau=0\.5\)",
r"unit scaling function: (add_1|input_6) \(residual-add, tau=0\.5\)",
]
print(caplog.text)

for log_msg in expected_logs:
assert log_msg in caplog.text
assert re.search(log_msg, caplog.text)


def test_fp8_unit_scaling(caplog: LogCaptureFixture) -> None:
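Because the unit-scaling log messages embed those version-dependent node names, the expected log lines become regular expressions with alternatives, checked with re.search rather than substring containment. A minimal sketch of the new assertion style, not part of this commit, with made-up captured text:

import re

captured = "unit scaling function: input_3 (residual-add, tau=0.5)\n"  # stand-in for caplog.text
pattern = r"unit scaling function: (skip_1|input_3) \(residual-add, tau=0\.5\)"
assert re.search(pattern, captured)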
129 changes: 62 additions & 67 deletions unit_scaling/transforms/utils.py
@@ -2,23 +2,22 @@

"""Utilities for working with transforms."""

import copy
import functools
from contextlib import contextmanager
from copy import copy, deepcopy
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
TypeVar,
no_type_check,
)
from unittest.mock import patch

import torch
import torch._dynamo
from torch import Tensor, nn
from torch.fx.graph import Graph
@@ -35,58 +34,41 @@
_unit_scaled_functions = [getattr(U, f) for f in U.__all__]


def _get_patched_allowed_function_ids(
non_recurse_functions: Iterable[Callable[..., Any]],
) -> Set[int]:
allowed_function_ids = copy(torch._dynamo.allowed_functions._allowed_function_ids)
for v in nn.modules.__dict__.values():
if isinstance(v, type) and v not in nn.modules.loss.__dict__.values():
i = id(v)
if i in allowed_function_ids:
allowed_function_ids.remove(i)
for f in non_recurse_functions:
allowed_function_ids.add(id(f))
return allowed_function_ids # type: ignore[no-any-return]


def _patched_call_function( # type: ignore[no-untyped-def]
self,
tx,
args,
kwargs,
): # pragma: no cover
if isinstance(self.obj, torch._dynamo.variables.NNModuleVariable):
module_attr = getattr(self.fn, "__module__", "")
if (
module_attr is not None
and module_attr.startswith("torch.nn.modules.module")
or self.is_constant
):
return self.obj.call_method( # type: ignore[no-untyped-call]
tx, self.fn.__name__, args, kwargs, constant=self.is_constant
).add_options(self)
return super(
torch._dynamo.variables.functions.UserMethodVariable, self
).call_function(tx, args, kwargs)


@contextmanager
def _expand_modules_patch(non_recurse_functions): # type: ignore[no-untyped-def]
patcher_a = patch(
"torch._dynamo.allowed_functions._allowed_function_ids",
new=_get_patched_allowed_function_ids(non_recurse_functions),
)
patcher_b = patch(
"torch._dynamo.variables.functions.UserMethodVariable.call_function",
new=_patched_call_function,
)
with patcher_a, patcher_b:
yield (patcher_a.start(), patcher_b.start())


def patch_to_expand_modules(
fn: Callable[..., T], non_recurse_functions: Iterable[Callable[..., Any]] = ()
) -> Callable[..., T]:
def torch_nn_modules_to_user_modules(mod: nn.Module) -> Any:
"""
Convert torch.nn.module classes to `trivial_subclass` versions.
By default TorchDynamo doesn't recurse into :mod:`torch.nn` modules or
:mod:`torch.nn.functional` functions when capturing the FX graph.
This function makes `torch.nn` modules into user modules.
To use this with a :class:`torch.nn.Module` the typical use case
is to call `module = torch_nn_modules_to_user_modules(module)`.
"""

for n, submod in mod.named_modules():
# Mirroring the check in https://github.com/pytorch/pytorch/blob/72662bf05b3499ce96aae9183a489c78f0c44c84/torch/_dynamo/variables/functions.py#L335 # noqa: E501
if submod.__module__.startswith("torch.nn."):
# Generate a new name, so e.g. torch.nn.modules.sparse.Embedding
# becomes trivial_subclass_modules_sparse_Embedding
modulename = submod.__module__
modulename = modulename.replace("torch.nn.", "", 1)
modulename = modulename.replace(".", "_")
newtypename = "trivial_subclass_" + modulename + "_" + type(submod).__name__

# Create a new type object deriving from type(submod)
newmodtype = type(newtypename, (type(submod),), {})

# Initialize and copy state using pickle
newsubmod = newmodtype.__new__(newmodtype) # type: ignore [call-overload]
newsubmod.__setstate__(submod.__getstate__())

# Update module in mod
setattr(mod, n, newsubmod)


def patch_to_expand_modules(fn: Callable[..., T]) -> Callable[..., T]:
"""By default TorchDynamo doesn't recurse into :mod:`torch.nn` modules or
:mod:`torch.nn.functional` functions when capturing the FX graph.
Any function which is wrapped in
@@ -98,21 +80,32 @@ def patch_to_expand_modules(
is to call `module = torch._dynamo.optimize(backend)(module)`, followed by
`module.forward = patch_to_expand_modules(module.forward)`.
In addition, any functions the user *doesn't* wish to recurse into can be passed
into `non_recurse_functions` and these will not be expanded.
This should be used in conjunction with :func:`torch_nn_modules_to_user_modules`
Args:
fn (Callable[..., T]): the function to be patched.
non_recurse_functions (Iterable[Callable[..., Any]], optional): functions which
the user does not wish to be recursed into. Defaults to ().
Returns:
Callable[..., T]: the new version of `fn` with patching applied.
"""

def _patched_call_function( # type: ignore[no-untyped-def]
self,
tx,
args,
kwargs,
): # pragma: no cover
# Removing the check in https://github.com/pytorch/pytorch/blob/72662bf05b3499ce96aae9183a489c78f0c44c84/torch/_dynamo/variables/functions.py#L335 # noqa: E501
return super(
torch._dynamo.variables.functions.UserMethodVariable, self
).call_function(tx, args, kwargs)

@functools.wraps(fn)
def new_fn(*args: Any, **kwargs: Any) -> Any:
with _expand_modules_patch(non_recurse_functions):
with patch(
"torch._dynamo.variables.functions.UserMethodVariable.call_function",
new=_patched_call_function,
):
return fn(*args, **kwargs)

return new_fn
@@ -211,22 +204,24 @@ def apply_transform(
Returns:
nn.Module: the transformed module.
"""
module = deepcopy(module)
module = copy.deepcopy(module)

torch_nn_modules_to_user_modules(module)

if not hasattr(module, "backends"):
module.backends = []
module.backends.append(backend)
if not hasattr(module, "non_recurse_functions"):
module.non_recurse_functions = list(_unit_scaled_functions)
module.non_recurse_functions += non_recurse_functions

for v in non_recurse_functions:
torch._dynamo.allow_in_graph(v)

backend = _compose_backends(module.backends)

def new_forward(*args: Any, **kwargs: Any) -> Any:
if module.rerun_transform:
torch._dynamo.reset()
dynamo_module = torch._dynamo.optimize(backend)(module)
module.dynamo_forward = patch_to_expand_modules(
dynamo_module.forward, module.non_recurse_functions
)
module.dynamo_forward = patch_to_expand_modules(dynamo_module.forward)
module.rerun_transform = False
with patch.object(module, "forward", module.base_forward):
return module.dynamo_forward(*args, **kwargs)
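Taken together, the rewritten utils.py drops the old _allowed_function_ids patching: torch_nn_modules_to_user_modules rewrites each torch.nn submodule into a trivial subclass so that TorchDynamo treats it as user code, and non-recurse functions are registered with torch._dynamo.allow_in_graph instead of being threaded through patch_to_expand_modules. A rough usage sketch following the docstrings and apply_transform above, not part of this commit; TinyModel and the identity backend are hypothetical:

import copy

import torch
import torch._dynamo
import torch.fx
from torch import nn

from unit_scaling.transforms.utils import (
    patch_to_expand_modules,
    torch_nn_modules_to_user_modules,
)

class TinyModel(nn.Module):  # hypothetical user model
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))

def backend(gm: torch.fx.GraphModule, example_inputs):  # hypothetical identity backend
    print(gm.graph)  # with the rewrite, nodes from inside nn.Linear appear here too
    return gm

module = copy.deepcopy(TinyModel())
torch_nn_modules_to_user_modules(module)  # e.g. nn.Linear -> trivial_subclass_modules_linear_Linear
# torch._dynamo.allow_in_graph(some_fn)   # keep selected functions as single graph nodes
dynamo_module = torch._dynamo.optimize(backend)(module)
forward = patch_to_expand_modules(dynamo_module.forward)
out = forward(torch.randn(2, 4))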
