diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py
index d39985f..5a36140 100644
--- a/float8_experimental/float8_dynamic_linear.py
+++ b/float8_experimental/float8_dynamic_linear.py
@@ -182,7 +182,7 @@ def unwrap(t):
         )
 
     def __tensor_flatten__(self):
-        return ["_tensor", "_pre_computed_amax"], self._mm_config
+        return ["_tensor"], self._mm_config
 
     @staticmethod
     def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py
index a3f5283..0afd114 100644
--- a/float8_experimental/float8_linear_utils.py
+++ b/float8_experimental/float8_linear_utils.py
@@ -5,7 +5,8 @@
 # LICENSE file in the root directory of this source tree.
 import copy
 import logging
-import math
+
+# import math
 import warnings
 from enum import auto, Enum
 from typing import Callable, List, Optional, Type
@@ -344,15 +345,17 @@ def precompute_float8_amax(module: nn.Module) -> None:
     weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears]
 
     def compute_amaxes(weights: List[DTensor]):
-        max_weights = torch._foreach_norm(weights, ord=math.inf)
+        abs_weights = torch._foreach_abs(weights)  # S0
+        max_weights = [torch.max(a) for a in abs_weights]
+        # max_weights = torch._foreach_norm(weights, ord=math.inf)
         amax_tensor = torch.vstack(max_weights)
         amax_tensor = torch.clamp(amax_tensor, EPS)  # R
         amaxes = torch.split(amax_tensor, 1)  # R
         return amaxes
 
     if weights:
-        amaxes = compute_amaxes(weights)
-        # amaxes = torch.compile(compute_amaxes, mode="reduce-overhead")(weights)
+        # amaxes = compute_amaxes(weights)
+        amaxes = torch.compile(compute_amaxes, mode="reduce-overhead")(weights)
         # amaxes = torch.compile(compute_amaxes)(weights)
         for amax, float8_linear in zip(amaxes, float8_linears):
             float8_linear.weight._local_tensor._pre_computed_amax = amax._local_tensor
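
The last hunk replaces torch._foreach_norm(weights, ord=math.inf) with torch._foreach_abs plus a per-tensor max, presumably to keep compute_amaxes friendly to the torch.compile(mode="reduce-overhead") path the diff also enables. The sketch below is not part of the patch; it illustrates, on plain tensors (no DTensor/FSDP, a placeholder EPS constant, and a hypothetical compute_amaxes_sketch name), that the new formulation yields the same per-weight amax as the old inf-norm one.

import torch

EPS = 1e-12  # placeholder for float8_experimental's EPS constant

def compute_amaxes_sketch(weights):
    abs_weights = torch._foreach_abs(weights)           # elementwise |w| for each tensor
    max_weights = [torch.max(a) for a in abs_weights]   # scalar amax per tensor
    amax_tensor = torch.vstack(max_weights)              # stack into shape (num_weights, 1)
    amax_tensor = torch.clamp(amax_tensor, EPS)          # keep amax strictly positive
    return torch.split(amax_tensor, 1)                   # one (1, 1) amax per weight

weights = [torch.randn(16, 32) for _ in range(3)]
amaxes = compute_amaxes_sketch(weights)
# Each amax matches the inf-norm (max absolute value) of the corresponding weight.
assert all(torch.allclose(a.squeeze(), w.abs().max()) for a, w in zip(amaxes, weights))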