Skip to content

Commit

Permalink
Merge pull request #53 from graphcore-research/awf/refactor-2
Browse files Browse the repository at this point in the history
Proposal to refactor a state-carrying closure to a class
  • Loading branch information
thecharlieblake committed Apr 20, 2024
2 parents f5c1e96 + 1c27b6f commit d60dbc3
Showing 1 changed file with 15 additions and 17 deletions.
32 changes: 15 additions & 17 deletions unit_scaling/transforms/_track_scales.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node, Target

from .utils import Backend, apply_transform
from .utils import apply_transform

logger = logging.getLogger(__name__)
M = TypeVar("M", bound=nn.Module)
Expand Down Expand Up @@ -115,7 +115,7 @@ def new_forward(*args: Any, **kwargs: Any) -> Any:
module.forward = new_forward


class _Track(torch.autograd.Function):
class ScaleTrackingAutogradFunction(torch.autograd.Function):
@staticmethod
def forward( # type:ignore[override]
ctx: torch.autograd.function.FunctionCtx,
Expand Down Expand Up @@ -189,7 +189,7 @@ def _get_tracking_meta(n: Node, out: Any) -> Dict[str, Any]:
}


class _Tracker(Interpreter):
class ScaleTrackingInterpreter(Interpreter):
def __init__(self, gm: GraphModule) -> None:
super().__init__(gm)

Expand All @@ -198,22 +198,23 @@ def run_node(self, n: Node) -> Any:
n.meta.update(_get_tracking_meta(n, out))
if n.meta["outputs_float_tensor"]:
logger.info("adding tracking to node: %s", n)
out = _Track.apply(out, n.meta) # type: ignore
out = ScaleTrackingAutogradFunction.apply(out, n.meta) # type: ignore
return out

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Make the interpreter directly callable (a dynamo backend must return
        # a callable) by delegating to Interpreter.run with the same arguments.
        return super().run(*args, **kwargs)


def scale_tracking_backend(graph_holder: List[Graph]) -> Backend:
def inner_backend(
gm: GraphModule, example_inputs: List[Tensor]
class ScaleTrackingBackend:
    """A stateful ``torch.compile`` backend that instruments a traced module
    for scale tracking.

    The FX graph produced by the most recent compilation is kept on
    ``self.graph`` so it can be inspected after the fact (replacing the
    earlier list-based ``graph_holder`` closure pattern).
    """

    def __init__(self) -> None:
        # Placeholder until the backend is first invoked during compilation.
        self.graph = Graph()

    def __call__(
        self, gm: GraphModule, example_inputs: List[Tensor]
    ) -> Callable[..., Any]:
        traced_graph = gm.graph
        # Expose the graph so it can be accessed from outside the backend.
        self.graph = traced_graph
        # Attach a rich tabular rendering (full info when shown in notebooks).
        _add_tabular_html_display(traced_graph)
        return ScaleTrackingInterpreter(gm)


def _prune(graph: Graph, node: Node, replacement_arg: Optional[Node] = None) -> None:
Expand Down Expand Up @@ -302,13 +303,10 @@ def track_scales(module: M) -> M:
Returns:
M: a new version of the input module which tracks tensor metrics when used.
"""
graph_holder = [Graph()]
tracking_module = apply_transform(module, scale_tracking_backend(graph_holder))

def scales_graph() -> Graph:
return graph_holder[0] # type: ignore
backend = ScaleTrackingBackend()
tracking_module = apply_transform(module, backend)

tracking_module.scales_graph = scales_graph
tracking_module.scales_graph = lambda: backend.graph
_make_input_tensors_require_grad(tracking_module)
return tracking_module # type: ignore[no-any-return]

Expand Down

0 comments on commit d60dbc3

Please sign in to comment.