diff --git a/README.md b/README.md
index cb744f1..9436696 100644
--- a/README.md
+++ b/README.md
@@ -43,19 +43,19 @@ from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_f
 # create model
 m = Model(...)

-# optional: filter layers from being eligible for float8 conversion
-def layer_filter_fn(fqn: str, mod: torch.nn.Module):
-    # don't convert the output layer
+# optional: filter modules from being eligible for float8 conversion
+def module_filter_fn(fqn: str, mod: torch.nn.Module):
+    # don't convert the output module
     if fqn == "output":
         return False
-    # don't convert linear layers with weight dimensions not divisible by 16
+    # don't convert linear modules with weight dimensions not divisible by 16
     if isinstance(mod, torch.nn.Linear):
         if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
             return False
     return True

 # convert all `torch.nn.Linear` modules to `Float8Linear`
-swap_linear_with_float8_linear(m, layer_filter_fn=layer_filter_fn)
+swap_linear_with_float8_linear(m, module_filter_fn=module_filter_fn)

 # optional: use FSDP
 model = FSDP(model, use_orig_params=True)
diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py
index caac4a8..8140baa 100644
--- a/float8_experimental/float8_linear_utils.py
+++ b/float8_experimental/float8_linear_utils.py
@@ -63,7 +63,7 @@ def swap_linear_layers(
     module: nn.Module,
     from_float_func: Callable[[nn.Linear], nn.Linear],
     *,
-    layer_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
+    module_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
 ) -> Optional[nn.Module]:
     """
     Generic function to swap linear layers in a module with a new type of linear layer.
@@ -75,7 +75,7 @@ def swap_linear_layers(
     Args:
         module: Module to modify.
         from_float_func: Function that accepts a linear layer and returns a new type of linear layer.
-        layer_filter_fn: If specified, only the modules that
+        module_filter_fn: If specified, only the `torch.nn.Linear` subclasses
             that pass the filter function will be swapped. The inputs to the
             filter function are the FQN and module instance.

@@ -83,9 +83,7 @@
         nn.Module: The modified module with swapped linear layers.
     """
     if isinstance(module, nn.Linear) and (
-        # linear_layer_filter is None or linear_layer_filter(module)
-        layer_filter_fn is None
-        or layer_filter_fn("", module)
+        module_filter_fn is None or module_filter_fn("", module)
     ):
         if len(list(module.children())) > 0:
             raise AssertionError(
@@ -115,8 +113,8 @@ def post_order_traversal(

         if isinstance(module, nn.Linear) and (
             # linear_layer_filter is None or linear_layer_filter(module)
-            layer_filter_fn is None
-            or layer_filter_fn(cur_fqn, module)
+            module_filter_fn is None
+            or module_filter_fn(cur_fqn, module)
         ):
             assert (
                 parent_module is not None
@@ -133,7 +131,7 @@ def swap_linear_with_float8_linear(
     module: nn.Module,
     *,
     emulate: bool = False,
-    layer_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
+    module_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
     scaling_type_input: TensorScalingType = TensorScalingType.DYNAMIC,
     scaling_type_weight: TensorScalingType = TensorScalingType.DYNAMIC,
     scaling_type_grad_output: TensorScalingType = TensorScalingType.DYNAMIC,
@@ -144,7 +142,7 @@ def swap_linear_with_float8_linear(
     Args:
         module: Module to modify.
        emulate: If True, emulation is used instead of hardware accelerated gemm
-        layer_filter_fn: If specified, only the modules that
+        module_filter_fn: If specified, only the `torch.nn.Linear` subclasses
            that pass the filter function will be swapped. The inputs to the
            filter function are the FQN and module instance.
        scaling_type_input (TensorScalingType): scaling type for `input`
@@ -164,7 +162,7 @@ def swap_linear_with_float8_linear(
     return swap_linear_layers(
         module,
         from_float,
-        layer_filter_fn=layer_filter_fn,
+        module_filter_fn=module_filter_fn,
     )


diff --git a/float8_experimental/inference.py b/float8_experimental/inference.py
index 342b6a1..814ce58 100644
--- a/float8_experimental/inference.py
+++ b/float8_experimental/inference.py
@@ -209,7 +209,7 @@ def quantize_to_float8(
     module: nn.Module,
     quant_config: QuantConfig,
     *,
-    layer_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
+    module_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
     use_fast_accum: bool = True,
 ) -> Optional[nn.Module]:
     """
@@ -222,7 +222,7 @@
     Args:
         module (nn.Module): The module to modify.
         quant_config (QuantConfig): Quantization configuration for Float8 conversion.
-        layer_filter_fn: If specified, only the modules that
+        module_filter_fn: If specified, only the `torch.nn.Linear` subclasses
            that pass the filter function will be swapped. The inputs to the
            filter function are the FQN and module instance.
        use_fast_accum : Whether to enable fast accumulation for the Float8InferenceLinear. Defaults to True.
@@ -236,5 +236,5 @@
     return swap_linear_layers(
         module,
         lambda m: Float8InferenceLinear.from_float(m, quant_config, use_fast_accum),
-        layer_filter_fn=layer_filter_fn,
+        module_filter_fn=module_filter_fn,
     )
diff --git a/test/test_base.py b/test/test_base.py
index 1d2f994..6841c6f 100644
--- a/test/test_base.py
+++ b/test/test_base.py
@@ -640,7 +640,7 @@ def __init__(self, dim: int):

         size_limit = 32

-        def layer_filter_fn(fqn, mod):
+        def module_filter_fn(fqn, mod):
             return (
                 mod.in_features >= size_limit
                 and mod.out_features >= size_limit
@@ -651,7 +651,7 @@
         model = swap_linear_with_float8_linear(
             model,
             emulate=True,
-            layer_filter_fn=layer_filter_fn,
+            module_filter_fn=module_filter_fn,
         )
         # in_features=8, out_features=32, 8 is less than 32.
         self.assertNotIsInstance(model[0].lin1, Float8Linear)
@@ -672,14 +672,14 @@ def __init__(self, dim: int):
                 self.lin2 = nn.Linear(4 * dim, dim)

         model = nn.Sequential(MLP(3), nn.Linear(3, 3), MLP(3))
-        layer_filter_fn = lambda fqn, mod: fqn not in [
+        module_filter_fn = lambda fqn, mod: fqn not in [
             "0.lin2",
             "2.lin1",
         ]
         model = swap_linear_with_float8_linear(
             model,
             emulate=True,
-            layer_filter_fn=layer_filter_fn,
+            module_filter_fn=module_filter_fn,
         )
         self.assertTrue(type(model[0].lin1) is Float8Linear)
         self.assertTrue(type(model[0].lin2) is nn.Linear)
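
For reference, here is a minimal sketch of the renamed keyword on the inference path, which this diff updates but does not demonstrate. The toy `nn.Sequential` model and filter are hypothetical, and the `QuantConfig(ActivationCasting.DYNAMIC)` spelling is assumed from `float8_experimental.inference`; only `quantize_to_float8` and the `(fqn, module)` filter contract are taken directly from the diff:

```python
import torch.nn as nn

from float8_experimental.inference import (
    ActivationCasting,  # assumed enum name from float8_experimental.inference
    QuantConfig,
    quantize_to_float8,
)

# hypothetical toy model: the second linear has out_features=8
model = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 8))

def module_filter_fn(fqn: str, mod: nn.Module) -> bool:
    # same (fqn, module) contract as the training-path filter; only called
    # on nn.Linear instances, so in_features/out_features are safe to read
    return mod.in_features % 16 == 0 and mod.out_features % 16 == 0

# only the first linear passes the filter and is swapped for inference
model = quantize_to_float8(
    model,
    QuantConfig(ActivationCasting.DYNAMIC),
    module_filter_fn=module_filter_fn,
)
```

As in the training-path examples above, the filter keyword is `module_filter_fn=` after this change; passing the old `layer_filter_fn=` keyword will raise a `TypeError`.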