[WIP,POC] Faster functional modules #983
@@ -8,7 +8,6 @@
import torch.nn as nn
from torch import Tensor
from typing import List, Tuple
from .named_members_polyfill import _named_parameters, _named_buffers
import copy

# Utilities to make nn.Module "functional"
@@ -56,66 +55,12 @@ def raise_parameter_tying_error():
                       "https://github.com/pytorch/functorch/issues/446")


def create_names_map(named_params, tied_named_params):
    """
    named_params is a dictionary of tensors: {'A': A, 'B': B}
    tied_named_params is another dictionary of tensors {'A': A, 'B': B, 'B_tied': B}
    with potentially tied (or 'duplicated') tensors

    This function creates a mapping from the names in named_params to the
    names in tied_named_params: {'A': ['A'], 'B': ['B', 'B_tied']}.
    """
    named_params = {k: v for k, v in named_params}
    tied_named_params = {k: v for k, v in tied_named_params}

    tensors_dict_keys = set(named_params.keys())
    tied_tensors_dict_keys = set(tied_named_params.keys())
    assert tensors_dict_keys.issubset(tied_tensors_dict_keys)

    tensor_to_mapping = {}
    for key, tensor in named_params.items():
        tensor_to_mapping[tensor] = (key, [])
    for key, tensor in tied_named_params.items():
        assert tensor in tensor_to_mapping
        tensor_to_mapping[tensor][1].append(key.split('.'))
    result = {key: value for key, value in tensor_to_mapping.values()}
    return result


def _extract_members(mod: nn.Module, _named_members, named_members, subclass):
    all_named_members = tuple(_named_members(mod, remove_duplicate=False))
    named_members = tuple(named_members())
    names_map = create_names_map(named_members, all_named_members)

    # Remove all the members in the model
    memo = {}
    for name, p in all_named_members:
        if p not in memo:
            memo[p] = subclass(torch.empty_like(p, device='meta'))
        replacement = memo[p]
        _set_nested_attr(mod, name.split("."), replacement)

    if len(named_members) == 0:
        names, params = (), ()
    else:
        names, params = zip(*named_members)
    return params, names, names_map


def extract_weights(mod: nn.Module):
    """
    This function removes all the Parameters from the model and
    return them as a tuple as well as their original attribute names.
    The weights must be re-loaded with `load_weights` before the model
    can be used again.
    Note that this function modifies the model in place and after this
    call, mod.parameters() will be empty.
    """
    return _extract_members(mod, _named_parameters, mod.named_parameters, nn.Parameter)


def extract_buffers(mod: nn.Module):
    return _extract_members(mod, _named_buffers, mod.named_buffers, lambda x: x)
def extract_weights(model):
    for module_name, m in model.named_modules():
        for param_name, p in list(m.named_parameters(recurse=False)):
            delattr(m, param_name)
            setattr(m, param_name, None)
            yield (module_name, m, param_name, p)


def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], as_params=False) -> None:
@@ -131,15 +76,21 @@ def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], a
        _set_nested_attr(mod, name.split("."), p)


def _swap_state(mod: nn.Module, names_map: List[str], elems):
    result = []
    for (_, attr_names), elem in zip(names_map.items(), elems):
        for i, attr_name in enumerate(attr_names):
            if i == 0:
                result.append(_get_nested_attr(mod, attr_name))
            _del_nested_attr(mod, attr_name)
            _set_nested_attr(mod, attr_name, elem)
    return result
def _swap_state(param_modules, param_names, params):
    old_params = []
    for module, param_name, param in zip(param_modules, param_names, params):
        old_params.append(getattr(module, param_name))
        delattr(module, param_name)
        setattr(module, param_name, param)
    return old_params


def extract_buffers(model):
    for module_name, m in model.named_modules():
        for buffer_name, b in list(m.named_buffers(recurse=False)):
            delattr(m, buffer_name)
            setattr(m, buffer_name, None)
            yield (module_name, m, buffer_name, b)


def load_buffers(mod: nn.Module, names: List[str], buffers: Tuple[Tensor, ...], as_params=False) -> None:
@@ -194,7 +145,7 @@ def make_functional_deprecated_v1(model: nn.Module):
    if len(buffers) > 0:
        raise RuntimeError('make_functional_deprecated_v1(model): `model` has buffers. Please use '
                           'make_functional_with_buffers_deprecated_v1(model) instead.')
    weights, descriptors, _ = extract_weights(model)
    weights, descriptors = extract_weights(model)

    def fun(weights, data):
        mutable_model = copy.deepcopy(model)
@@ -241,79 +192,175 @@ def fun(weights, buffers, data):
    return weights, buffers, fun, weight_descriptors, buf_descriptors


def make_split_names(lst):
    return [name.split('.') for name in lst]


class FunctionalModuleWithBuffers(nn.Module):
    """
    This is the callable object returned by :func:`make_functional_with_buffers`.
    """

    def __init__(self, stateless_model, param_names, buffer_names,
                 param_names_map, buffer_names_map):
    def __init__(
        self,
        stateless_model,
        param_module_names,
        param_modules,
        param_names,
        buffer_module_names,
        buffer_modules,
        buffer_names
    ):
        super(FunctionalModuleWithBuffers, self).__init__()
        self.stateless_model = stateless_model
        self.param_module_names = param_module_names
        self.param_modules = param_modules
        self.param_names = param_names
        self.buffer_module_names = buffer_module_names
        self.buffer_modules = buffer_modules
        self.buffer_names = buffer_names

        self.all_names_map = dict(param_names_map)
        self.all_names_map.update(buffer_names_map)

    @staticmethod
    def _create_from(model, disable_autograd_tracking=False):
    def _create_from(model):
        # TODO: We don't need to copy the model to create a stateless copy
        model_copy = copy.deepcopy(model)
        params, param_names, param_names_map = extract_weights(model_copy)
        buffers, buffer_names, buffer_names_map = extract_buffers(model_copy)
        param_container = list(extract_weights(model_copy))
        if len(param_container):
            param_module_names, param_modules, param_names, params = zip(*param_container)
        else:
            param_module_names, param_modules, param_names, params = tuple(), tuple(), tuple(), tuple()
Comment on lines +229 to +231

Previously, we guaranteed that params is returned in the same order as what gets returned by original_model.parameters(). After this change, is that still true?

(Side note) To be honest, we've been thinking of changing the API so that params isn't returned as a flat list; instead we probably want to return some sort of dictionary or object so that one can easily figure out which

Oh man this would be so great! I could definitely use that feature. To be honest, I was thinking about using TensorDict from torchrl to pass params to functorch stateless modules. We could nest the dicts (e.g. d["module"]["param"] to d["module.param"]), expand the params, change device or whatever, in batch and with little or no effort since all those ops are built-in tensordict methods. I think there's a good synergy that we could get from TensorDict + functorch. At the moment, TensorDict isn't torchscriptable though; I don't know how much trouble that is for you.
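As a rough sketch of the name-keyed container idea being discussed (hand-rolled for illustration rather than using TensorDict itself; the helpers params_as_nested_dict and flatten_keys are hypothetical), nesting a model's parameters by module path and flattening them back to dotted keys could look like this:

import torch.nn as nn

def params_as_nested_dict(model: nn.Module) -> dict:
    # Hypothetical helper: key parameters by module path,
    # e.g. {'0': {'weight': w, 'bias': b}} instead of a flat tuple.
    nested = {}
    for name, p in model.named_parameters():
        *path, leaf = name.split('.')
        node = nested
        for part in path:
            node = node.setdefault(part, {})
        node[leaf] = p
    return nested

def flatten_keys(nested: dict, prefix: str = '') -> dict:
    # Convert {'0': {'weight': w}} back to {'0.weight': w}.
    flat = {}
    for key, value in nested.items():
        full_key = prefix + key
        if isinstance(value, dict):
            flat.update(flatten_keys(value, full_key + '.'))
        else:
            flat[full_key] = value
    return flat

model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 1))
nested = params_as_nested_dict(model)
flat = flatten_keys(nested)
assert set(flat) == {name for name, _ in model.named_parameters()}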
        if disable_autograd_tracking:
            for param in params:
                param.requires_grad_(False)

        buffer_container = list(extract_buffers(model_copy))
        if len(buffer_container):
            buffer_module_names, buffer_modules, buffer_names, buffers = zip(*buffer_container)
        else:
            buffer_module_names, buffer_modules, buffer_names, buffers = tuple(), tuple(), tuple(), tuple()
        return (
            FunctionalModuleWithBuffers(model_copy, param_names, buffer_names,
                                        param_names_map, buffer_names_map),
            FunctionalModuleWithBuffers(
                model_copy,
                param_module_names,
                param_modules,
                param_names,
                buffer_module_names,
                buffer_modules,
                buffer_names),
            params,
            buffers,
        )

    def forward(self, params, buffers, *args, **kwargs):
        # Temporarily load the state back onto self.stateless_model
        old_state = _swap_state(
            self.stateless_model,
            self.all_names_map,
            list(params) + list(buffers))
        old_states = _swap_state(
            self.param_modules + self.buffer_modules,
            self.param_names + self.buffer_names,
            list(params) + list(buffers)
        )
        old_params = old_states[:len(self.param_modules)]
        old_buffers = old_states[len(self.param_modules):]
Comment on lines +255 to +262

Just to make sure I understand why this is faster: is it because we no longer need to traverse through the module to find the submodules; we've already made the submodules directly available to swap their parameters out?

Exactly!
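To make the comparison concrete, here is a small sketch under the assumption that the old-style swap resolves each dotted name from the root module on every call (set_by_dotted_name below is a hypothetical stand-in for the nested-attribute helpers, not code from this PR), while the new-style swap caches direct references to the owning submodules so each swap is a single delattr/setattr:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(2, 2), nn.Sequential(nn.Linear(2, 2)))

def set_by_dotted_name(root: nn.Module, dotted: str, value) -> None:
    # Old-style swap: walk the module tree from the root for every tensor,
    # on every forward call.
    *path, leaf = dotted.split('.')
    mod = root
    for part in path:
        mod = getattr(mod, part)
    delattr(mod, leaf)
    setattr(mod, leaf, value)

set_by_dotted_name(model, '1.0.weight', torch.zeros(2, 2))

# New-style swap: resolve the owning submodules once, up front ...
owners = [(m, name) for _, m in model.named_modules()
          for name, _ in m.named_parameters(recurse=False)]
new_params = [torch.zeros_like(getattr(m, name)) for m, name in owners]

# ... after which each call only needs a direct delattr/setattr per tensor.
for (owner, name), p in zip(owners, new_params):
    delattr(owner, name)
    setattr(owner, name, p)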
        try:
            return self.stateless_model(*args, **kwargs)
        finally:
            # Remove the loaded state on self.stateless_model
            _swap_state(self.stateless_model, self.all_names_map, old_state)
            for module, param_name, param in zip(self.param_modules, self.param_names, old_params):
                old_params.append(getattr(module, param_name))
                setattr(module, param_name, param)
            for module, buffer_name, buffer in zip(self.buffer_modules, self.buffer_names, old_buffers):
                old_buffers.append(getattr(module, buffer_name))
                setattr(module, buffer_name, buffer)

    def __getstate__(self):
        state = self.__dict__.copy()
        state["param_modules"] = None
        state["buffer_modules"] = None
        return state

    def __setstate__(self, state):
        state["param_modules"] = []
        state["buffer_modules"] = []
        out = super().__setstate__(state)
        for module_name in self.param_module_names:
            found = False
            for other_name, module in self.stateless_model.named_modules():
                if other_name == module_name:
                    found = True
                    state["param_modules"].append(module)
                    break
            if not found:
                raise RuntimeError(f"module not found: {module_name}")
        for module_name in self.buffer_module_names:
            found = False
            for other_name, module in self.stateless_model.named_modules():
                if other_name == module_name:
                    found = True
                    state["buffer_modules"].append(module)
                    break
            if not found:
                raise RuntimeError(f"module not found: {module_name}")
        return out


class FunctionalModule(nn.Module):
    """
    This is the callable object returned by :func:`make_functional`.
    This is the callable object returned by :func:`make_functional_with_buffers`.
    """

    def __init__(self, stateless_model, param_names, names_map):
    def __init__(self, stateless_model, param_module_names, modules, param_names):
        super(FunctionalModule, self).__init__()
        self.stateless_model = stateless_model
        self.param_modules = modules
        self.param_names = param_names
        self.names_map = names_map
        self.param_module_names = param_module_names

    @staticmethod
    def _create_from(model, disable_autograd_tracking=False):
    def _create_from(model):
        # TODO: We don't need to copy the model to create a stateless copy
        model_copy = copy.deepcopy(model)
        params, param_names, names_map = extract_weights(model_copy)
        param_container = list(extract_weights(model_copy))
        if len(param_container):
            module_names, param_modules, param_names, params = zip(*param_container)
        else:
            module_names, param_modules, param_names, params = tuple(), tuple(), tuple(), tuple()
        if disable_autograd_tracking:
            for param in params:
                param.requires_grad_(False)
        return FunctionalModule(model_copy, param_names, names_map), params

        return (
            FunctionalModule(model_copy, module_names, param_modules, param_names),
            params,
        )

    def forward(self, params, *args, **kwargs):
        # Temporarily load the state back onto self.stateless_model
        old_state = _swap_state(self.stateless_model, self.names_map, params)
        old_params = _swap_state(self.param_modules, self.param_names, params)

        try:
            return self.stateless_model(*args, **kwargs)
        finally:
            # Remove the loaded state on self.stateless_model
            _swap_state(self.stateless_model, self.names_map, old_state)
            for module, param_name, param in zip(self.param_modules, self.param_names, old_params):
                old_params.append(getattr(module, param_name))
                setattr(module, param_name, param)

    def __getstate__(self):
        state = self.__dict__.copy()
        state["param_modules"] = None
        return state

    def __setstate__(self, state):
        state["param_modules"] = []
        out = super().__setstate__(state)
        for module_name in self.param_module_names:
            found = False
            for other_name, module in self.stateless_model.named_modules():
                if other_name == module_name:
                    found = True
                    state["param_modules"].append(module)
                    break
            if not found:
                raise RuntimeError(f"module not found: {module_name}")
        return out


def make_functional(model: nn.Module, disable_autograd_tracking: bool = False):
The reason we previously used create_names_map was for parameter tying. If someone creates a module that looks like the Foo sketched after this comment, then

fmodel, params = make_functional(Foo())

returns 2 Tensors (self.linear.weight and self.bias) instead of 3 Tensors. When the user calls fmodel([w, b], x), b gets loaded to self.bias and self.linear.bias, and w gets loaded to self.linear.weight.

Under the new strategy, it seems like params would have 3 tensors: [self.bias, self.linear.weight, self.linear.bias]. In general I'm not really sure what the interaction between parameter tying and make_functional should be. Thoughts?
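A minimal sketch of the kind of tied-parameter module described above (the class name Foo and its exact layout are assumptions; the point is only that self.bias and self.linear.bias are the same Parameter object):

import torch.nn as nn

class Foo(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 3)
        # Tied parameter: self.bias is the same Parameter object
        # as self.linear.bias.
        self.bias = self.linear.bias

    def forward(self, x):
        return self.linear(x) + self.bias

model = Foo()
# named_parameters() de-duplicates tied tensors by default, so only two
# names are reported even though three attributes hold Parameters.
print([name for name, _ in model.named_parameters()])  # ['bias', 'linear.weight']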
I see your point. If we want to keep things as they are, we could link params to a list of modules and a list of names (instead of a single module and a single name). That will come with a slight overhead though...
It's the kind of design choice where you will always make someone unhappy (there will be someone out there that wants multiple copies of the same param), but it's probably not the majority of users.
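A rough sketch of that direction, with hypothetical helpers extract_unique_weights and load_unique_weights that are not part of this PR: each unique Parameter is recorded once together with every (module, attribute) site it occupies, so tied copies collapse to a single entry in params while loading still restores all sites.

import torch.nn as nn

def extract_unique_weights(model: nn.Module):
    # Hypothetical variant of extract_weights: tied parameters are grouped,
    # so each unique tensor maps to every (module, attr_name) site it occupies.
    unique = {}  # id(param) -> param
    sites = {}   # id(param) -> list of (module, attr_name)
    for m in model.modules():
        for name, p in list(m.named_parameters(recurse=False)):
            unique.setdefault(id(p), p)
            sites.setdefault(id(p), []).append((m, name))
            delattr(m, name)
            setattr(m, name, None)
    return list(unique.values()), list(sites.values())

def load_unique_weights(params, site_lists):
    # Write each tensor back to every location it was tied to.
    for p, locations in zip(params, site_lists):
        for m, name in locations:
            setattr(m, name, p)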