clarify public API of float8_experimental (#330)
Summary:
Pull Request resolved: #330

Makes the following functions public:
* convert_to_float8_training and all of its configuration
* linear_requires_sync
* sync_float8_amax_and_scale_history
* precompute_float8_dynamic_scale_for_fsdp

Everything else is private. The fbsource counterpart of this PR will
remove usage of private APIs.
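For illustration, a minimal sketch of the resulting public surface after this change (the toy model and the call with default settings are assumptions, not part of this commit):

```python
import torch
from float8_experimental import (
    convert_to_float8_training,
    linear_requires_sync,  # reports whether a config needs amax/scale syncing
    precompute_float8_dynamic_scale_for_fsdp,
    sync_float8_amax_and_scale_history,
)

# hypothetical toy model; any module containing nn.Linear works
m = torch.nn.Sequential(torch.nn.Linear(32, 32))
convert_to_float8_training(m)  # swaps eligible nn.Linear children for fp8 training
```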

Reviewed By: weifengpy

Differential Revision: D60195666

fbshipit-source-id: 2e99475cf7f852b91b4c96687a7f229a2c8b3adf
vkuzo authored and facebook-github-bot committed Jul 25, 2024
1 parent da487a3 commit a6cef5a
Showing 4 changed files with 22 additions and 8 deletions.
8 changes: 4 additions & 4 deletions README.md
@@ -35,10 +35,10 @@ We provide two per-tensor scaling strategies: dynamic and delayed. See https://
 This is the most accurate recipe as every tensor is scaled dynamically.
 
 ```python
-from float8_experimental.float8_linear_utils import (
+from float8_experimental import (
     convert_to_float8_training,
+    precompute_float8_dynamic_scale_for_fsdp,
 )
-from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
 
 # create model
 m = Model(...)
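A hedged sketch of how these two imports compose in a training step, since the README's full example is collapsed above (the model, device, loop, and FSDP details are assumptions):

```python
import torch
from float8_experimental import (
    convert_to_float8_training,
    precompute_float8_dynamic_scale_for_fsdp,
)

m = torch.nn.Sequential(torch.nn.Linear(64, 64)).cuda()  # placeholder model
convert_to_float8_training(m)
# real use would wrap `m` with FSDP and build an optimizer here

for _ in range(3):  # assumed loop; fp8 matmul needs an H100-class GPU
    m(torch.randn(8, 64, device="cuda")).sum().backward()
    # optimizer.step() would go here
    # then precompute fp8 scales for FSDP float8 weights in one batched call
    # (returns early when no FSDP-managed float8 weights are present)
    precompute_float8_dynamic_scale_for_fsdp(m)
```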
@@ -82,11 +82,11 @@ for _ in range(N_ITER):
 This is theoretically the most performant recipe as it minimizes memory reads.
 
 ```python
-from float8_experimental.float8_linear_utils import (
+from float8_experimental import (
     convert_to_float8_training,
     sync_float8_amax_and_scale_history,
+    TensorScalingType,
 )
-from float8_experimental.float8_linear import TensorScalingType
 
 # create model
 m = Model(...)
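A hedged sketch of the delayed-scaling loop these imports feed into (the README's full loop is collapsed; model, optimizer, and configuration details here are assumptions):

```python
import torch
from float8_experimental import (
    convert_to_float8_training,
    sync_float8_amax_and_scale_history,
    TensorScalingType,
)

m = torch.nn.Sequential(torch.nn.Linear(64, 64)).cuda()  # placeholder model
# delayed scaling is selected via TensorScalingType.DELAYED in the conversion
# configuration; the exact config spelling is omitted here
convert_to_float8_training(m)
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)

for _ in range(3):  # assumed training loop
    optimizer.zero_grad()
    m(torch.randn(8, 64, device="cuda")).sum().backward()
    # delayed scaling tracks amax history in buffers; sync them before the step
    # (linear_requires_sync reports whether the chosen config needs this)
    sync_float8_amax_and_scale_history(m)
    optimizer.step()
```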
14 changes: 10 additions & 4 deletions float8_experimental/__init__.py
@@ -10,13 +10,18 @@
     TensorScalingType,
 )
 from float8_experimental.float8_linear import Float8Linear
-from float8_experimental.float8_linear_utils import convert_to_float8_training
+from float8_experimental.float8_linear_utils import (
+    convert_to_float8_training,
+    linear_requires_sync,
+    sync_float8_amax_and_scale_history,
+)
 from float8_experimental.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
     LinearMMConfig,
     ScaledMMConfig,
 )
+from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
 
 # Needed to load Float8Tensor with weights_only = True
 from torch.serialization import add_safe_globals
@@ -30,7 +35,8 @@
"Float8TensorCastConfig",
# top level UX
"convert_to_float8_training",
# TODO(future): remove Float8Tensor and Float8Linear from public API
"Float8Tensor",
"Float8Linear",
"linear_requires_sync",
"sync_float8_amax_and_scale_history",
"precompute_float8_dynamic_scale_for_fsdp",
# note: Float8Tensor and Float8Linear are not public APIs
]
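The `add_safe_globals` import near the top of this file is what the "weights_only = True" comment refers to; a minimal sketch of the pattern it enables (the checkpoint file name is a placeholder):

```python
import torch
import float8_experimental  # registers Float8Tensor et al. with add_safe_globals

# weights_only=True restricts unpickling to allowlisted types; the
# registration above is what lets fp8 checkpoints load under this mode
state_dict = torch.load("fp8_checkpoint.pt", weights_only=True)
```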
4 changes: 4 additions & 0 deletions float8_experimental/float8_linear.py
@@ -150,6 +150,10 @@ def __init__(self, history_len: int = 16, scale_fn_name: str = "max"):
 
 class Float8Linear(torch.nn.Linear):
     """
+    Note: this is **not** a public API and is only intended to be used
+    inside of this repository. Please file an issue if you would benefit
+    from this being a public API.
+
     A wrapper around a `torch.nn.Linear` module which does fp8 compute, and tracks
     scales in a way friendly to delayed scaling.
     """
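Given this note, a sketch of the supported path: obtain `Float8Linear` modules only through the public converter (the toy module is an assumption):

```python
import torch
from float8_experimental import convert_to_float8_training
from float8_experimental import Float8Linear  # importable, but not a public API

m = torch.nn.Sequential(torch.nn.Linear(16, 16))
convert_to_float8_training(m)  # the supported way to create Float8Linear modules
assert isinstance(m[0], Float8Linear)  # inspection only; do not construct directly
```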
4 changes: 4 additions & 0 deletions float8_experimental/float8_tensor.py
@@ -252,6 +252,10 @@ def backward(ctx, g):
 
 class Float8Tensor(torch.Tensor):
     """
+    Note: this is **not** a public API and is only intended to be used
+    inside of this repository. Please file an issue if you would benefit
+    from this being a public API.
+
     A Python-only Float8 tensor subclass. Contains:
     * `_data`: the underlying e4m3 or e5m2 data
     * `_scale`: the scale used to scale the original fp32 tensor. We multiply
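For intuition about the `_data` and `_scale` fields described above, a plain-PyTorch sketch of per-tensor scaling (illustrative only, not the library's implementation):

```python
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

x = torch.randn(4, 4)                       # original high-precision tensor
scale = E4M3_MAX / x.abs().max().clamp(min=1e-12)
data = (x * scale).to(torch.float8_e4m3fn)  # roughly what `_data` holds
x_roundtrip = data.float() / scale          # dequantize to check the error
print((x - x_roundtrip).abs().max())
```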
