Small fixes to Float8Tensor (#1225)
* Fixes to Float8Tensor

Signed-off-by: Przemyslaw Tredak <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Przemyslaw Tredak <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
ptrendx and pre-commit-ci[bot] authored Oct 10, 2024
1 parent 85e60e6 commit 3b89c36
Showing 2 changed files with 34 additions and 37 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -39,3 +39,4 @@ develop-eggs/
 dist/
 downloads/
 .pytest_cache/
+compile_commands.json
70 changes: 33 additions & 37 deletions transformer_engine/pytorch/tensor/float8_tensor.py
@@ -126,12 +126,9 @@ def forward(

         # Check scale
         if scale is None and fp8_meta is None:
-            scale = 1
+            scale = torch.full([1], 1, dtype=torch.float32, device=device)
         if scale is not None:
-            if isinstance(scale, torch.Tensor):
-                scale = scale.to(device=device, dtype=torch.float32)
-            else:
-                scale = torch.full([1], scale, dtype=torch.float32, device=device)
+            scale = scale.to(device=device, dtype=torch.float32)

         # Check scale-inverse
         if scale_inv is None:
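
The hunk above drops the float-to-tensor fallback, so callers of this cast path are now expected to pass `scale` as a `torch.Tensor` (or leave it as `None`). A minimal caller-side sketch of preparing such a scale; the names `device` and `my_scale` are illustrative, not from this commit:

```python
import torch

# Hypothetical values for illustration only.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
my_scale = 0.5  # a plain Python float is no longer converted automatically

# Wrap the float in a 1-element float32 tensor on the target device,
# mirroring what the updated forward() does for the default scale of 1.
scale = torch.full([1], my_scale, dtype=torch.float32, device=device)
assert isinstance(scale, torch.Tensor) and scale.dtype == torch.float32
```
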
@@ -335,6 +332,18 @@ class Float8Tensor(QuantizedTensor):
     """

+    _data: torch.Tensor
+    _fp8_attrs: Dict[str, Any]
+    _fp8_meta: Optional[Dict[str, Any]]
+    _fp8_meta_forward: bool
+    _fp8_meta_index: Optional[int]
+    _fp8_dtype: TE_DType
+    _scale_inv: torch.Tensor
+
+    # FP8 transpose cache
+    _transpose: Optional[torch.Tensor]
+    _transpose_invalid: bool
+
     def __new__(
         cls,
         *,
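
The added class-level annotations declare each instance attribute once, so the assignments in `__new__` (later in this diff) can drop their inline annotations. A toy sketch of the same pattern, not the real `Float8Tensor`:

```python
from typing import Any, Dict, Optional

import torch


class _ToyQuantizedTensor:
    """Toy class showing instance attributes declared at class level."""

    # Declared once here; type checkers apply these to every instance.
    _data: torch.Tensor
    _attrs: Dict[str, Any]
    _meta_index: Optional[int]

    def __init__(self, data: torch.Tensor) -> None:
        # Plain assignments, no repeated inline annotations needed.
        self._data = data
        self._attrs = {}
        self._meta_index = None
```
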
@@ -371,13 +380,12 @@ def __new__(
             requires_grad=requires_grad,
             device=data.device,
         )
-        self._data: torch.Tensor = data
+        self._data = data

         # Initialize dict of class attributes
         # Note: We store FP8 attributes in a dictionary so we can
         # share them between tensors with the same data, e.g. detached
         # tensors.
-        self._fp8_attrs: dict
         if fp8_attrs is None:
             self._fp8_attrs = {}
         else:
@@ -390,16 +398,16 @@ def __new__(
                 "To initialize Float8Tensor with FP8 meta tensors, "
                 "the FP8 meta tensor index must also be provided"
             )
-        self._fp8_meta: Optional[Dict[str, Any]] = fp8_meta
-        self._fp8_meta_forward: bool = fp8_meta_forward
-        self._fp8_meta_index: Optional[int] = fp8_meta_index
+        self._fp8_meta = fp8_meta
+        self._fp8_meta_forward = fp8_meta_forward
+        self._fp8_meta_index = fp8_meta_index

         # FP8 dtype
         assert fp8_dtype in (
             TE_DType.kFloat8E4M3,
             TE_DType.kFloat8E5M2,
         ), f"Unsupported fp8_dtype {fp8_dtype}."
-        self._fp8_dtype: TE_DType = fp8_dtype
+        self._fp8_dtype = fp8_dtype

         # FP8 scale-inverse
         if fp8_scale_inv is None and self._fp8_meta is not None:
@@ -412,13 +420,6 @@ def __new__(
             raise ValueError(
                 "Attempted to initialize Float8Tensor without specifying scale-inverse"
             )
-        if not isinstance(fp8_scale_inv, torch.Tensor):
-            fp8_scale_inv = torch.full(
-                [1],
-                fp8_scale_inv,
-                dtype=torch.float32,
-                device=self._data.device,
-            )
         if fp8_scale_inv.numel() != 1:
             raise ValueError(
                 "Attempted to initialize Float8Tensor with invalid scale-inverse tensor"
@@ -433,11 +434,11 @@ def __new__(
                 device=self._data.device,
                 dtype=torch.float32,
             )
-        self._scale_inv: Optional[torch.Tensor] = fp8_scale_inv
+        self._scale_inv = fp8_scale_inv

         # FP8 transpose cache
-        self._transpose: Optional[Float8Tensor] = data_transpose
-        self._transpose_invalid: bool = self._transpose is None
+        self._transpose = data_transpose
+        self._transpose_invalid = self._transpose is None

         return self

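With the `isinstance` conversion removed above, `fp8_scale_inv` is expected to arrive as a 1-element `torch.Tensor` rather than a Python float. A hedged sketch of preparing that argument up front; the value and device are made up for illustration:

```python
import torch

# Hypothetical scale-inverse value for illustration only.
scale_inv_value = 2.0

# The constructor no longer wraps Python floats itself, so build a
# 1-element float32 tensor (ideally on the same device as the FP8 data).
fp8_scale_inv = torch.full([1], scale_inv_value, dtype=torch.float32, device="cpu")
assert fp8_scale_inv.numel() == 1

# This tensor would then be passed as the keyword argument
# fp8_scale_inv=... when constructing a Float8Tensor.
```
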
@@ -477,7 +478,7 @@ def __repr__(self):
             ")"
         )

-    def dequantize(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
+    def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:

         # Convert PyTorch dtype to TE dtype
         if dtype is None:
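
The bare `*` added to `dequantize` makes `dtype` keyword-only. A toy stand-in demonstrating the behavior of that signature (not the real method, which returns a dequantized tensor):

```python
from typing import Optional

import torch


def dequantize_demo(*, dtype: Optional[torch.dtype] = None) -> torch.dtype:
    """Toy function mirroring the new keyword-only signature."""
    return torch.float32 if dtype is None else dtype


# Keyword call works, mirroring tensor.dequantize(dtype=torch.bfloat16).
assert dequantize_demo(dtype=torch.bfloat16) is torch.bfloat16

# A positional call such as dequantize_demo(torch.bfloat16) now raises
# TypeError, which is exactly what the added `*` enforces.
```
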
@@ -603,11 +604,8 @@ def quantize_(

         # Make sure FP8 scaling factors are in expected format
         if scale is not None:
-            if isinstance(scale, torch.Tensor):
-                if not devices_match(scale.device, dst.device) or scale.dtype != torch.float32:
-                    scale = scale.to(device=dst.device, dtype=torch.float32)
-            else:
-                scale = torch.full([1], scale, dtype=torch.float32, device=dst.device)
+            if not devices_match(scale.device, dst.device) or scale.dtype != torch.float32:
+                scale = scale.to(device=dst.device, dtype=torch.float32)
         if amax is not None:
             while amax.dim() < 2:
                 amax = amax.unsqueeze(0)
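
The simplified branch above assumes `scale` is already a `torch.Tensor` and only moves or casts it when the device or dtype differs; `amax` is still promoted to at least 2-D by the surrounding code. A standalone sketch of that normalization, using a plain device comparison in place of the library's `devices_match` helper:

```python
from typing import Optional, Tuple

import torch


def normalize_fp8_factors(
    scale: Optional[torch.Tensor],
    amax: Optional[torch.Tensor],
    device: torch.device,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Toy version of the scale/amax normalization shown in the diff."""
    if scale is not None:
        # Only copy when something actually differs, as in the updated code.
        if scale.device != device or scale.dtype != torch.float32:
            scale = scale.to(device=device, dtype=torch.float32)
    if amax is not None:
        # Mirror the while-loop above: promote amax to at least 2-D.
        while amax.dim() < 2:
            amax = amax.unsqueeze(0)
    return scale, amax


scale, amax = normalize_fp8_factors(
    torch.tensor([0.5]), torch.tensor(3.0), torch.device("cpu")
)
assert scale.dtype == torch.float32 and amax.dim() == 2
```
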
@@ -781,23 +779,21 @@ def transpose_2d(
             fill_cache = False

         # Need to compute transpose if cache is invalid
-        need_compute = force_compute
-        if self._transpose is None:
-            need_compute = True
-        elif self._transpose_invalid:
-            need_compute = True
-
-        # Need to apply transpose kernel if noop flag is applied
-        if noop_flag is not None:
-            need_compute = True
+        need_compute = (
+            force_compute
+            or (self._transpose is None)
+            or self._transpose_invalid
+            or (noop_flag is not None)
+        )

         # Return cached transpose if possible
         if not need_compute:
+            assert self._transpose is not None
             return self._transpose

         # Allocate output if needed
         data = self._data.contiguous().reshape(-1, self.size(-1))
-        out = self._transpose
+        out: Optional[torch.Tensor] = self._transpose
         if out is None:
             out = torch.empty(
                 (data.size(1), data.size(0)),

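The last hunk folds four separate checks into one boolean expression for `need_compute` and adds an assertion before returning the cached transpose. A small self-contained sketch showing the old and new formulations agree; plain booleans stand in for the tensor state:

```python
from itertools import product


def need_compute_old(force_compute, transpose_is_none, transpose_invalid, has_noop_flag):
    """Old control flow: several statements that each set the flag."""
    need_compute = force_compute
    if transpose_is_none:
        need_compute = True
    elif transpose_invalid:
        need_compute = True
    if has_noop_flag:
        need_compute = True
    return need_compute


def need_compute_new(force_compute, transpose_is_none, transpose_invalid, has_noop_flag):
    """New form: a single boolean expression, as in the updated code."""
    return force_compute or transpose_is_none or transpose_invalid or has_noop_flag


# The two formulations agree for every combination of inputs.
for combo in product([False, True], repeat=4):
    assert need_compute_old(*combo) == need_compute_new(*combo)
```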