
Commit dc5eab0: torch.compile works
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
weifengpy committed Jun 6, 2024
1 parent 6f244a2 commit dc5eab0
Showing 2 changed files with 8 additions and 5 deletions.
float8_experimental/float8_dynamic_linear.py (1 addition, 1 deletion)

@@ -182,7 +182,7 @@ def unwrap(t):
         )
 
     def __tensor_flatten__(self):
-        return ["_tensor", "_pre_computed_amax"], self._mm_config
+        return ["_tensor"], self._mm_config
 
     @staticmethod
     def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
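For context on the one-line change above: __tensor_flatten__ and __tensor_unflatten__ are the hooks of PyTorch's traceable wrapper-subclass protocol, which torch.compile uses to decompose a tensor subclass into its inner tensors plus non-tensor metadata. Listing only "_tensor" means the lazily populated _pre_computed_amax attribute (set in the second file below) no longer has to appear in the flatten spec at trace time, which is presumably what unblocked compilation. A minimal sketch of the protocol, using a hypothetical WrapperTensor rather than the actual class in this repo:

import torch

# Hypothetical wrapper subclass illustrating the flatten/unflatten protocol;
# "WrapperTensor", "_tensor", and "_config" are stand-ins, not repo names.
class WrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, tensor, config):
        # Outer tensor with matching metadata but no storage of its own.
        return torch.Tensor._make_wrapper_subclass(
            cls, tensor.shape, dtype=tensor.dtype, device=tensor.device
        )

    def __init__(self, tensor, config):
        self._tensor = tensor  # inner tensor, named in __tensor_flatten__
        self._config = config  # non-tensor metadata, carried as the spec

    def __tensor_flatten__(self):
        # Name only inner tensors guaranteed to exist whenever tracing happens.
        return ["_tensor"], self._config

    @staticmethod
    def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
        # Rebuild the subclass from its flattened parts.
        return WrapperTensor(inner_tensors["_tensor"], flatten_spec)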
float8_experimental/float8_linear_utils.py (7 additions, 4 deletions)

@@ -5,7 +5,8 @@
 # LICENSE file in the root directory of this source tree.
 import copy
 import logging
-import math
+
+# import math
 import warnings
 from enum import auto, Enum
 from typing import Callable, List, Optional, Type

@@ -344,15 +345,17 @@ def precompute_float8_amax(module: nn.Module) -> None:
     weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears]
 
     def compute_amaxes(weights: List[DTensor]):
-        max_weights = torch._foreach_norm(weights, ord=math.inf)
+        abs_weights = torch._foreach_abs(weights)  # S0
+        max_weights = [torch.max(a) for a in abs_weights]
+        # max_weights = torch._foreach_norm(weights, ord=math.inf)
         amax_tensor = torch.vstack(max_weights)
         amax_tensor = torch.clamp(amax_tensor, EPS)  # R
         amaxes = torch.split(amax_tensor, 1)  # R
         return amaxes
 
     if weights:
-        amaxes = compute_amaxes(weights)
-        # amaxes = torch.compile(compute_amaxes, mode="reduce-overhead")(weights)
+        # amaxes = compute_amaxes(weights)
+        amaxes = torch.compile(compute_amaxes, mode="reduce-overhead")(weights)
         # amaxes = torch.compile(compute_amaxes)(weights)
     for amax, float8_linear in zip(amaxes, float8_linears):
         float8_linear.weight._local_tensor._pre_computed_amax = amax._local_tensor
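The compute_amaxes change swaps the foreach inf-norm for an elementwise abs followed by a per-tensor max. Both compute the same quantity (the maximum absolute value, or amax, of each weight), but the rewritten form presumably traced more cleanly, and the call site now runs the helper under torch.compile with mode="reduce-overhead", which uses CUDA graphs to amortize kernel-launch overhead. A runnable sketch of the same computation on plain tensors (assumptions: an illustrative EPS value, ordinary tensors rather than the DTensor shards the real code handles):

import torch

EPS = 1e-12  # illustrative; the repo defines its own epsilon

def compute_amaxes(weights):
    abs_weights = torch._foreach_abs(weights)          # elementwise |w| for the whole list
    max_weights = [torch.max(a) for a in abs_weights]  # per-tensor amax scalars
    amax_tensor = torch.vstack(max_weights)            # stack into shape (len(weights), 1)
    amax_tensor = torch.clamp(amax_tensor, EPS)        # avoid zero scales downstream
    return torch.split(amax_tensor, 1)                 # one (1, 1) tensor per weight

weights = [torch.randn(64, 64) for _ in range(3)]

# Same value as the inf-norm the diff replaces:
expected = [torch.linalg.vector_norm(w, ord=float("inf")) for w in weights]
for got, ref in zip(compute_amaxes(weights), expected):
    assert torch.allclose(got.squeeze(), ref)

# The commit's call pattern: compile the helper before invoking it.
amaxes = torch.compile(compute_amaxes, mode="reduce-overhead")(weights)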
