Skip to content
This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

Permalink
Update on "[bc breaking] change x, w, dL_dY variable names to input, …
Browse files Browse the repository at this point in the history
…weight, grad_output"

Summary:

The following naming scheme matches the rest of PyTorch better:

```
// forward
output = input @ weight_t
// backward
grad_input = grad_output @ weight
grad_weight = input_t @ grad_output
```

This PR changes all the previous references to `x`, `w`, `dL_dY` to
match the naming scheme above.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
  • Loading branch information
vkuzo committed Jul 22, 2024
2 parents c29fcbe + 0e81f87 commit d7bc4fe
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 22 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,19 @@ from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_f
# create model
m = Model(...)

# optional: filter layers from being eligible for float8 conversion
def layer_filter_fn(fqn: str, mod: torch.nn.Module):
# don't convert the output layer
# optional: filter modules from being eligible for float8 conversion
def module_filter_fn(fqn: str, mod: torch.nn.Module):
# don't convert the output module
if fqn == "output":
return False
# don't convert linear layers with weight dimensions not divisible by 16
# don't convert linear modules with weight dimensions not divisible by 16
if isinstance(mod, torch.nn.Linear):
if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
return False
return True

# convert all `torch.nn.Linear` modules to `Float8Linear`
swap_linear_with_float8_linear(m, layer_filter_fn=layer_filter_fn)
swap_linear_with_float8_linear(m, module_filter_fn=module_filter_fn)

# optional: use FSDP
model = FSDP(model, use_orig_params=True)
Expand Down
18 changes: 8 additions & 10 deletions float8_experimental/float8_linear_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def swap_linear_layers(
module: nn.Module,
from_float_func: Callable[[nn.Linear], nn.Linear],
*,
layer_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
module_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
) -> Optional[nn.Module]:
"""
Generic function to swap linear layers in a module with a new type of linear layer.
Expand All @@ -75,17 +75,15 @@ def swap_linear_layers(
Args:
module: Module to modify.
from_float_func: Function that accepts a linear layer and returns a new type of linear layer.
layer_filter_fn: If specified, only the modules that
module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that
that pass the filter function will be swapped. The inputs to the
filter function are the FQN and module instance.
Returns:
nn.Module: The modified module with swapped linear layers.
"""
if isinstance(module, nn.Linear) and (
# linear_layer_filter is None or linear_layer_filter(module)
layer_filter_fn is None
or layer_filter_fn("", module)
module_filter_fn is None or module_filter_fn("", module)
):
if len(list(module.children())) > 0:
raise AssertionError(
Expand Down Expand Up @@ -115,8 +113,8 @@ def post_order_traversal(

if isinstance(module, nn.Linear) and (
# linear_layer_filter is None or linear_layer_filter(module)
layer_filter_fn is None
or layer_filter_fn(cur_fqn, module)
module_filter_fn is None
or module_filter_fn(cur_fqn, module)
):
assert (
parent_module is not None
Expand All @@ -133,7 +131,7 @@ def swap_linear_with_float8_linear(
module: nn.Module,
*,
emulate: bool = False,
layer_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
module_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
scaling_type_input: TensorScalingType = TensorScalingType.DYNAMIC,
scaling_type_weight: TensorScalingType = TensorScalingType.DYNAMIC,
scaling_type_grad_output: TensorScalingType = TensorScalingType.DYNAMIC,
Expand All @@ -144,7 +142,7 @@ def swap_linear_with_float8_linear(
Args:
module: Module to modify.
emulate: If True, emulation is used instead of hardware accelerated gemm
layer_filter_fn: If specified, only the modules that
module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that
that pass the filter function will be swapped. The inputs to the
filter function are the FQN and module instance.
scaling_type_input (TensorScalingType): scaling type for `input`
Expand All @@ -164,7 +162,7 @@ def swap_linear_with_float8_linear(
return swap_linear_layers(
module,
from_float,
layer_filter_fn=layer_filter_fn,
module_filter_fn=module_filter_fn,
)


Expand Down
6 changes: 3 additions & 3 deletions float8_experimental/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def quantize_to_float8(
module: nn.Module,
quant_config: QuantConfig,
*,
layer_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
module_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
use_fast_accum: bool = True,
) -> Optional[nn.Module]:
"""
Expand All @@ -222,7 +222,7 @@ def quantize_to_float8(
Args:
module (nn.Module): The module to modify.
quant_config (QuantConfig): Quantization configuration for Float8 conversion.
layer_filter_fn: If specified, only the modules that
module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that
that pass the filter function will be swapped. The inputs to the
filter function are the FQN and module instance.
use_fast_accum : Whether to enable fast accumulation for the Float8InferenceLinear. Defaults to True.
Expand All @@ -236,5 +236,5 @@ def quantize_to_float8(
return swap_linear_layers(
module,
lambda m: Float8InferenceLinear.from_float(m, quant_config, use_fast_accum),
layer_filter_fn=layer_filter_fn,
module_filter_fn=module_filter_fn,
)
8 changes: 4 additions & 4 deletions test/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,7 @@ def __init__(self, dim: int):

size_limit = 32

def layer_filter_fn(fqn, mod):
def module_filter_fn(fqn, mod):
return (
mod.in_features >= size_limit
and mod.out_features >= size_limit
Expand All @@ -651,7 +651,7 @@ def layer_filter_fn(fqn, mod):
model = swap_linear_with_float8_linear(
model,
emulate=True,
layer_filter_fn=layer_filter_fn,
module_filter_fn=module_filter_fn,
)
# in_features=8, out_features=32, 8 is less than 32.
self.assertNotIsInstance(model[0].lin1, Float8Linear)
Expand All @@ -672,14 +672,14 @@ def __init__(self, dim: int):
self.lin2 = nn.Linear(4 * dim, dim)

model = nn.Sequential(MLP(3), nn.Linear(3, 3), MLP(3))
layer_filter_fn = lambda fqn, mod: fqn not in [
module_filter_fn = lambda fqn, mod: fqn not in [
"0.lin2",
"2.lin1",
]
model = swap_linear_with_float8_linear(
model,
emulate=True,
layer_filter_fn=layer_filter_fn,
module_filter_fn=module_filter_fn,
)
self.assertTrue(type(model[0].lin1) is Float8Linear)
self.assertTrue(type(model[0].lin2) is nn.Linear)
Expand Down

0 comments on commit d7bc4fe

Please sign in to comment.