[WIP,POC] Faster functional modules #983

Open · wants to merge 6 commits into base: main
239 changes: 143 additions & 96 deletions functorch/_src/make_functional.py
@@ -8,7 +8,6 @@
import torch.nn as nn
from torch import Tensor
from typing import List, Tuple
from .named_members_polyfill import _named_parameters, _named_buffers
import copy

# Utilities to make nn.Module "functional"
@@ -56,66 +55,12 @@ def raise_parameter_tying_error():
        "https://github.com/pytorch/functorch/issues/446")


def create_names_map(named_params, tied_named_params):
    """
    named_params is a dictionary of tensors: {'A': A, 'B': B}
    tied_named_params is another dictionary of tensors {'A': A, 'B': B, 'B_tied': B}
    with potentially tied (or 'duplicated') tensors

    This function creates a mapping from the names in named_params to the
    names in tied_named_params: {'A': ['A'], 'B': ['B', 'B_tied']}.
    """
    named_params = {k: v for k, v in named_params}
    tied_named_params = {k: v for k, v in tied_named_params}

    tensors_dict_keys = set(named_params.keys())
    tied_tensors_dict_keys = set(tied_named_params.keys())
    assert tensors_dict_keys.issubset(tied_tensors_dict_keys)

    tensor_to_mapping = {}
    for key, tensor in named_params.items():
        tensor_to_mapping[tensor] = (key, [])
    for key, tensor in tied_named_params.items():
        assert tensor in tensor_to_mapping
        tensor_to_mapping[tensor][1].append(key.split('.'))
    result = {key: value for key, value in tensor_to_mapping.values()}
    return result


def _extract_members(mod: nn.Module, _named_members, named_members, subclass):
    all_named_members = tuple(_named_members(mod, remove_duplicate=False))
    named_members = tuple(named_members())
    names_map = create_names_map(named_members, all_named_members)

    # Remove all the members in the model
    memo = {}
    for name, p in all_named_members:
        if p not in memo:
            memo[p] = subclass(torch.empty_like(p, device='meta'))
        replacement = memo[p]
        _set_nested_attr(mod, name.split("."), replacement)

    if len(named_members) == 0:
        names, params = (), ()
    else:
        names, params = zip(*named_members)
    return params, names, names_map


def extract_weights(mod: nn.Module):
    """
    This function removes all the Parameters from the model and
    returns them as a tuple, along with their original attribute names.
    The weights must be re-loaded with `load_weights` before the model
    can be used again.
    Note that this function modifies the model in place and after this
    call, mod.parameters() will be empty.
    """
    return _extract_members(mod, _named_parameters, mod.named_parameters, nn.Parameter)


def extract_buffers(mod: nn.Module):
    return _extract_members(mod, _named_buffers, mod.named_buffers, lambda x: x)
def extract_weights(model):
    for module_name, m in model.named_modules():
        for param_name, p in list(m.named_parameters(recurse=False)):
            delattr(m, param_name)
            setattr(m, param_name, None)
            yield (module_name, m, param_name, p)
Comment on lines +59 to +63

Contributor:
The reason we previously used create_names_map was for parameter tying. If someone creates a module that looks like:

class Foo(nn.Module):
    def __init__(self):
        super().__init__()
        self.bias = nn.Parameter(torch.randn(3))
        self.linear = nn.Linear(3, 3)
        self.linear.bias = self.bias

then fmodel, params = make_functional(Foo()) returns 2 Tensors (self.linear.weight and self.bias) instead of 3 Tensors. When the user calls fmodel([w, b], x), b gets loaded into both self.bias and self.linear.bias, and w gets loaded into self.linear.weight.

Under the new strategy, it seems like params would have 3 tensors: [self.bias, self.linear.weight, self.linear.bias].

In general I'm not really sure what the interaction between parameter tying and make_functional should be. Thoughts?
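[Editor's sketch] A runnable version of the tying scenario above, assuming a functorch build where make_functional deduplicates tied parameters as this comment describes; the forward method is added here so the functional call can run and is not part of the original example:

import torch
import torch.nn as nn
from functorch import make_functional

class Foo(nn.Module):
    def __init__(self):
        super().__init__()
        self.bias = nn.Parameter(torch.randn(3))
        self.linear = nn.Linear(3, 3)
        self.linear.bias = self.bias  # tie the two bias attributes

    def forward(self, x):
        return self.linear(x)

fmodel, params = make_functional(Foo())
# With deduplication, params holds 2 tensors (the weight and the shared
# bias); without it, it would hold 3.
print(len(params))

out = fmodel(params, torch.randn(3))  # the shared bias is loaded into both slots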

Author:

I see your point. If we want to keep things as they are, we could link each param to a list of modules and a list of names (instead of a single module and a single name). That would come with a slight overhead though...

It's the kind of design choice where you will always make someone unhappy (there will be someone out there who wants multiple copies of the same param), but that's probably not the majority of users.
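[Editor's sketch] A rough sketch of that alternative (hypothetical helper names, not code from this PR): each unique parameter keeps the full list of (module, attribute) pairs that reference it, so the flat params list stays deduplicated while a swap still hits every tied location.

import torch.nn as nn

def extract_weights_with_tying(model: nn.Module):
    # Map each unique parameter to all the (module, attr) pairs
    # that reference it, deduplicating by object identity.
    seen = {}      # id(param) -> index into entries
    entries = []   # (param, [(module, attr_name), ...])
    for _, m in model.named_modules():
        for name, p in list(m.named_parameters(recurse=False)):
            if id(p) in seen:
                entries[seen[id(p)]][1].append((m, name))
            else:
                seen[id(p)] = len(entries)
                entries.append((p, [(m, name)]))
            delattr(m, name)
    return entries

def swap_in(entries, params):
    # Load each tensor into every tied location it belongs to.
    for (_, targets), new_p in zip(entries, params):
        for m, name in targets:
            setattr(m, name, new_p)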



def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], as_params=False) -> None:
@@ -131,15 +76,21 @@ def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], a
        _set_nested_attr(mod, name.split("."), p)


def _swap_state(mod: nn.Module, names_map: List[str], elems):
    result = []
    for (_, attr_names), elem in zip(names_map.items(), elems):
        for i, attr_name in enumerate(attr_names):
            if i == 0:
                result.append(_get_nested_attr(mod, attr_name))
            _del_nested_attr(mod, attr_name)
            _set_nested_attr(mod, attr_name, elem)
    return result
def _swap_state(param_modules, param_names, params):
    old_params = []
    for module, param_name, param in zip(param_modules, param_names, params):
        old_params.append(getattr(module, param_name))
        delattr(module, param_name)
        setattr(module, param_name, param)
    return old_params


def extract_buffers(model):
    for module_name, m in model.named_modules():
        for buffer_name, b in list(m.named_buffers(recurse=False)):
            delattr(m, buffer_name)
            setattr(m, buffer_name, None)
            yield (module_name, m, buffer_name, b)


def load_buffers(mod: nn.Module, names: List[str], buffers: Tuple[Tensor, ...], as_params=False) -> None:
@@ -194,7 +145,7 @@ def make_functional_deprecated_v1(model: nn.Module):
    if len(buffers) > 0:
        raise RuntimeError('make_functional_deprecated_v1(model): `model` has buffers. Please use '
                           'make_functional_with_buffers_deprecated_v1(model) instead.')
    weights, descriptors, _ = extract_weights(model)
    weights, descriptors = extract_weights(model)

    def fun(weights, data):
        mutable_model = copy.deepcopy(model)
@@ -241,79 +192,175 @@ def fun(weights, buffers, data):
    return weights, buffers, fun, weight_descriptors, buf_descriptors


def make_split_names(lst):
    return [name.split('.') for name in lst]


class FunctionalModuleWithBuffers(nn.Module):
    """
    This is the callable object returned by :func:`make_functional_with_buffers`.
    """

    def __init__(self, stateless_model, param_names, buffer_names,
                 param_names_map, buffer_names_map):
    def __init__(
        self,
        stateless_model,
        param_module_names,
        param_modules,
        param_names,
        buffer_module_names,
        buffer_modules,
        buffer_names
    ):
        super(FunctionalModuleWithBuffers, self).__init__()
        self.stateless_model = stateless_model
        self.param_module_names = param_module_names
        self.param_modules = param_modules
        self.param_names = param_names
        self.buffer_module_names = buffer_module_names
        self.buffer_modules = buffer_modules
        self.buffer_names = buffer_names

        self.all_names_map = dict(param_names_map)
        self.all_names_map.update(buffer_names_map)

    @staticmethod
    def _create_from(model, disable_autograd_tracking=False):
    def _create_from(model):
        # TODO: We don't need to copy the model to create a stateless copy
        model_copy = copy.deepcopy(model)
        params, param_names, param_names_map = extract_weights(model_copy)
        buffers, buffer_names, buffer_names_map = extract_buffers(model_copy)
        param_container = list(extract_weights(model_copy))
        if len(param_container):
            param_module_names, param_modules, param_names, params = zip(*param_container)
        else:
            param_module_names, param_modules, param_names, params = tuple(), tuple(), tuple(), tuple()
Comment on lines +229 to +231

zou3519 (Contributor), Jul 26, 2022:
Previously, we guaranteed that params is returned in the same order as what gets returned by original_model.parameters(). After this change, is that still true?

(Side note) To be honest, we've been thinking of changing the API so that params isn't returned as a flat list; instead we probably want to return some sort of dictionary or object, so that one can easily figure out which entry corresponds to which parameter on the original module. This is something that a couple of users have asked us for. If we returned a dictionary, then it wouldn't matter that params differs from what gets returned by original_module.parameters().

vmoens (Author), Aug 3, 2022:

Oh man, this would be so great! I could definitely use that feature.

To be honest, I was thinking about using TensorDict from torchrl to pass params to functorch stateless modules. We could nest the dicts (e.g. d["module"]["param"] vs d["module.param"]), expand the params, change device, or whatever, in batch and with little or no effort, since all those ops are built-in TensorDict methods. I think there's a good synergy we could get from TensorDict + functorch. At the moment TensorDict isn't torchscriptable, though; I don't know how much trouble that is for you.
@nairbv @shagunsodhani
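[Editor's sketch] A rough illustration of the dictionary-shaped params idea (hypothetical helpers; neither this PR nor released functorch exposes such an API): parameters keyed by dotted name, plus a trivial conversion between the flat and nested layouts mentioned above.

import torch.nn as nn

def params_as_dict(model: nn.Module) -> dict:
    # Flat layout: {"linear.weight": tensor, "linear.bias": tensor, ...}
    return dict(model.named_parameters())

def nest(flat: dict) -> dict:
    # {"linear.weight": t} -> {"linear": {"weight": t}}
    nested = {}
    for dotted, tensor in flat.items():
        *path, leaf = dotted.split(".")
        node = nested
        for part in path:
            node = node.setdefault(part, {})
        node[leaf] = tensor
    return nested

model = nn.Sequential(nn.Linear(3, 3))
flat = params_as_dict(model)   # keys like "0.weight", "0.bias"
nested = nest(flat)            # nested["0"]["weight"]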

        if disable_autograd_tracking:
            for param in params:
                param.requires_grad_(False)

        buffer_container = list(extract_buffers(model_copy))
        if len(buffer_container):
            buffer_module_names, buffer_modules, buffer_names, buffers = zip(*buffer_container)
        else:
            buffer_module_names, buffer_modules, buffer_names, buffers = tuple(), tuple(), tuple(), tuple()
        return (
            FunctionalModuleWithBuffers(model_copy, param_names, buffer_names,
                                        param_names_map, buffer_names_map),
            FunctionalModuleWithBuffers(
                model_copy,
                param_module_names,
                param_modules,
                param_names,
                buffer_module_names,
                buffer_modules,
                buffer_names),
            params,
            buffers,
        )

    def forward(self, params, buffers, *args, **kwargs):
        # Temporarily load the state back onto self.stateless_model
        old_state = _swap_state(
            self.stateless_model,
            self.all_names_map,
            list(params) + list(buffers))
        old_states = _swap_state(
            self.param_modules + self.buffer_modules,
            self.param_names + self.buffer_names,
            list(params) + list(buffers)
        )
        old_params = old_states[:len(self.param_modules)]
        old_buffers = old_states[len(self.param_modules):]

Comment on lines +255 to +262

Contributor:

Just to make sure I understand why this is faster: is it because we no longer need to traverse through the module to find the submodules; we've already made the submodules directly available to swap their parameters out?

Author:

Exactly! Instead of going through a tree of param names, we just flatten it and go through a single flat list of modules, one-level attribute names, and values.
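[Editor's sketch] To make the difference concrete (hypothetical helper names, mirroring the PR's delattr/setattr pattern): the old path re-resolves each dotted name with a chain of getattr calls on every forward, while the new path keeps a direct (module, name) pair and performs a single swap.

import torch
import torch.nn as nn

def swap_nested(root: nn.Module, dotted: str, value):
    # Old-style swap: walk the module tree on every call,
    # as _get/_set_nested_attr did.
    *path, leaf = dotted.split(".")
    mod = root
    for part in path:
        mod = getattr(mod, part)
    old = getattr(mod, leaf)
    delattr(mod, leaf)          # deregister before assigning a plain tensor
    setattr(mod, leaf, value)
    return old

def swap_flat(mod: nn.Module, name: str, value):
    # New-style swap: the (module, name) pair was precomputed once,
    # so the hot path is one getattr/delattr/setattr.
    old = getattr(mod, name)
    delattr(mod, name)
    setattr(mod, name, value)
    return old

model = nn.Sequential(nn.Sequential(nn.Linear(3, 3)))
w = torch.zeros(3, 3)
old = swap_nested(model, "0.0.weight", w)     # resolves two submodules first
old2 = swap_flat(model[0][0], "weight", old)  # direct reference, no walk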

        try:
            return self.stateless_model(*args, **kwargs)
        finally:
            # Remove the loaded state on self.stateless_model
            _swap_state(self.stateless_model, self.all_names_map, old_state)
            # Restore the original parameters and buffers
            for module, param_name, param in zip(self.param_modules, self.param_names, old_params):
                setattr(module, param_name, param)
            for module, buffer_name, buffer in zip(self.buffer_modules, self.buffer_names, old_buffers):
                setattr(module, buffer_name, buffer)

    def __getstate__(self):
        state = self.__dict__.copy()
        state["param_modules"] = None
        state["buffer_modules"] = None
        return state

    def __setstate__(self, state):
        state["param_modules"] = []
        state["buffer_modules"] = []
        out = super().__setstate__(state)
        for module_name in self.param_module_names:
            found = False
            for other_name, module in self.stateless_model.named_modules():
                if other_name == module_name:
                    found = True
                    state["param_modules"].append(module)
                    break
            if not found:
                raise RuntimeError(f"module not found: {module_name}")
        for module_name in self.buffer_module_names:
            found = False
            for other_name, module in self.stateless_model.named_modules():
                if other_name == module_name:
                    found = True
                    state["buffer_modules"].append(module)
                    break
            if not found:
                raise RuntimeError(f"module not found: {module_name}")
        return out


class FunctionalModule(nn.Module):
    """
    This is the callable object returned by :func:`make_functional`.
    """

    def __init__(self, stateless_model, param_names, names_map):
    def __init__(self, stateless_model, param_module_names, modules, param_names):
        super(FunctionalModule, self).__init__()
        self.stateless_model = stateless_model
        self.param_modules = modules
        self.param_names = param_names
        self.names_map = names_map
        self.param_module_names = param_module_names

    @staticmethod
    def _create_from(model, disable_autograd_tracking=False):
    def _create_from(model):
        # TODO: We don't need to copy the model to create a stateless copy
        model_copy = copy.deepcopy(model)
        params, param_names, names_map = extract_weights(model_copy)
        param_container = list(extract_weights(model_copy))
        if len(param_container):
            module_names, param_modules, param_names, params = zip(*param_container)
        else:
            module_names, param_modules, param_names, params = tuple(), tuple(), tuple(), tuple()
        if disable_autograd_tracking:
            for param in params:
                param.requires_grad_(False)
        return FunctionalModule(model_copy, param_names, names_map), params

        return (
            FunctionalModule(model_copy, module_names, param_modules, param_names),
            params,
        )

    def forward(self, params, *args, **kwargs):
        # Temporarily load the state back onto self.stateless_model
        old_state = _swap_state(self.stateless_model, self.names_map, params)
        old_params = _swap_state(self.param_modules, self.param_names, params)

        try:
            return self.stateless_model(*args, **kwargs)
        finally:
            # Remove the loaded state on self.stateless_model
            _swap_state(self.stateless_model, self.names_map, old_state)
            # Restore the original parameters
            for module, param_name, param in zip(self.param_modules, self.param_names, old_params):
                setattr(module, param_name, param)

    def __getstate__(self):
        state = self.__dict__.copy()
        state["param_modules"] = None
        return state

    def __setstate__(self, state):
        state["param_modules"] = []
        out = super().__setstate__(state)
        for module_name in self.param_module_names:
            found = False
            for other_name, module in self.stateless_model.named_modules():
                if other_name == module_name:
                    found = True
                    state["param_modules"].append(module)
                    break
            if not found:
                raise RuntimeError(f"module not found: {module_name}")
        return out
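[Editor's sketch] A quick illustration of the round trip the __getstate__/__setstate__ hooks above are there to support (hypothetical usage; it assumes this PR's branch, where make_functional returns a FunctionalModule carrying param_module_names): the direct module references are dropped before pickling and re-resolved by name afterwards.

import pickle
import torch
import torch.nn as nn

model = nn.Linear(3, 3)
fmodel, params = make_functional(model)

restored = pickle.loads(pickle.dumps(fmodel))
# __setstate__ rebuilt restored.param_modules from param_module_names,
# so the unpickled module can run a functional forward as usual:
out = restored(params, torch.randn(2, 3))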


def make_functional(model: nn.Module, disable_autograd_tracking: bool = False):
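[Editor's sketch] For context, the user-facing pattern this file serves, unchanged by the PR (a standard functorch example; the per-sample gradient computation is illustrative):

import torch
import torch.nn as nn
from functorch import make_functional, vmap, grad

model = nn.Linear(3, 1)
fmodel, params = make_functional(model)

def loss(params, x, y):
    return ((fmodel(params, x) - y) ** 2).mean()

# Per-sample gradients: grad w.r.t. params, vmapped over the batch.
x, y = torch.randn(8, 3), torch.randn(8, 1)
per_sample_grads = vmap(grad(loss), in_dims=(None, 0, 0))(params, x, y)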