This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Updates with new scaled-mm api #284

Closed
wants to merge 1 commit into from
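Context for the diff below: the private `torch._scaled_mm` operator changed so that `scale_a` and `scale_b` are required arguments and the op returns a single tensor instead of an `(output, output_amax)` tuple. The sketch below illustrates the before/after call pattern this PR adapts to; it assumes a PyTorch build with the float8 dtypes and an fp8-capable GPU (e.g. H100), and the exact signature of this private API may differ across versions.

```python
import torch

# Sketch only: torch._scaled_mm is a private API whose signature has shifted
# across PyTorch versions; this mirrors the call pattern used in this PR.
device = "cuda"
A = torch.randn(16, 32, device=device).to(torch.float8_e4m3fn)
# _scaled_mm expects the second operand in column-major layout.
B = torch.randn(32, 16, device=device).to(torch.float8_e4m3fn).t().contiguous().t()
scale_a = torch.tensor(1.0, device=device)  # unit scales, as in the updated benchmark
scale_b = torch.tensor(1.0, device=device)

# Old API (pre-PR): scales were optional keywords and a tuple was returned.
# out, out_amax = torch._scaled_mm(A, B, out_dtype=torch.bfloat16, use_fast_accum=False)

# New API targeted by this PR: scales are required and a single tensor is returned.
out = torch._scaled_mm(A, B, scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=False)
```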
2 changes: 0 additions & 2 deletions benchmarks/bench_linear_float8.py
@@ -14,8 +14,6 @@

import torch
import torch.utils.benchmark as benchmark
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
get_float8_linear,
LinearType,
6 changes: 5 additions & 1 deletion benchmarks/bench_matmul.py
@@ -101,7 +101,11 @@ def run(n_limit: Optional[int] = None):
B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()

def do_matmul(A, B):
return torch._scaled_mm(A, B, out_dtype=d3, use_fast_accum=False)
scale_a = torch.tensor([1], device=device)
scale_b = torch.tensor([1], device=device)
return torch._scaled_mm(
A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=False
)

fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks(
tops, dtype_to_peak_tops[d1], do_matmul, A, B
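The benchmark only measures throughput, so unit scales are enough here. As a hedged sketch (the surrounding `do_benchmarks` helper and setup in `bench_matmul.py` are not part of this diff), a closure like the updated `do_matmul` above could be timed with `torch.utils.benchmark`:

```python
import torch.utils.benchmark as benchmark

def median_time_sec(do_matmul, A, B) -> float:
    # Median wall-clock seconds per call of the closure defined in run() above;
    # assumes A and B are already laid out as _scaled_mm expects.
    t = benchmark.Timer(
        stmt="do_matmul(A, B)",
        globals={"do_matmul": do_matmul, "A": A, "B": B},
    )
    return t.blocked_autorange().median
```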
7 changes: 3 additions & 4 deletions float8_experimental/float8_aten_api.py
@@ -10,7 +10,6 @@

import torch

from float8_experimental.float8_utils import tensor_to_amax
from torch.library import Library


@@ -26,7 +25,7 @@ def mm_float8_emulated(
m2_fp32 = m2.float() / s2
m3_fp32 = torch.mm(m1_fp32, m2_fp32)

return m3_fp32.to(dtype3), tensor_to_amax(m3_fp32)
return m3_fp32.to(dtype3)


#
@@ -38,7 +37,7 @@ def mm_float8_emulated(
lib = Library("aten", "FRAGMENT")

lib.define(
"mm_float8_emulated(Tensor m1, Tensor s1, Tensor m2, Tensor s2, ScalarType dtype3) -> (Tensor, Tensor)"
"mm_float8_emulated(Tensor m1, Tensor s1, Tensor m2, Tensor s2, ScalarType dtype3) -> Tensor"
)
lib.impl("mm_float8_emulated", mm_float8_emulated, "CPU")
lib.impl("mm_float8_emulated", mm_float8_emulated, "CUDA")
@@ -47,4 +46,4 @@ def mm_float8_emulated(
@torch.library.impl(lib, "mm_float8_emulated", "Meta")
def _mm_float8_emulated_meta(m1, s1, m2, s2, dtype3):
out = torch.mm(m1.float(), m2.float()).to(dtype3)
return out, torch.empty(1, device="meta")
return out
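After this change the emulated path also returns just the result tensor, keeping it signature-compatible with the real kernel. A minimal sketch of invoking the registered op directly, assuming `float8_experimental` is installed and the PyTorch build has the float8 dtypes (normally the `Float8Tensor` dispatch in `float8_ops.py` calls this for you when `mm_config.emulate` is set):

```python
import torch
import float8_experimental.float8_aten_api  # noqa: F401  (registers the op)

m1 = torch.randn(8, 16).to(torch.float8_e4m3fn)
m2 = torch.randn(16, 8).to(torch.float8_e4m3fn)
s1 = torch.tensor(1.0)  # hypothetical unit scales for illustration
s2 = torch.tensor(1.0)

# Returns a single bfloat16 tensor; callers no longer index [0] or unpack an amax.
out = torch.ops.aten.mm_float8_emulated(m1, s1, m2, s2, torch.bfloat16)
print(out.shape)  # torch.Size([8, 8])
```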
8 changes: 4 additions & 4 deletions float8_experimental/float8_ops.py
@@ -147,8 +147,8 @@ def float8_mm(aten_op, args, kwargs=None):
if mm_config.emulate:
return torch.ops.aten.mm_float8_emulated(
a._data, a._scale, b._data, b._scale, output_dtype
)[0]
tensor_out, amax = addmm_float8_unwrapped(
)
tensor_out = addmm_float8_unwrapped(
a_data,
a_scale,
b_data,
@@ -180,9 +180,9 @@ def float8_addmm(aten_op, args, kwargs=None):
if mm_config.emulate:
out = torch.ops.aten.mm_float8_emulated(
a._data, a._scale, b._data, b._scale, output_dtype
)[0]
)
return out + bias
tensor_out, amax = addmm_float8_unwrapped(
tensor_out = addmm_float8_unwrapped(
a_data,
a_scale,
b_data,
18 changes: 9 additions & 9 deletions float8_experimental/float8_python_api.py
@@ -10,7 +10,7 @@
"""


from typing import Optional, Tuple
from typing import Optional

import float8_experimental.float8_aten_api # noqa

@@ -31,7 +31,7 @@ def addmm_float8_unwrapped(
output_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
use_fast_accum: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
"""
This is the unwrapped version of addmm_float8, which does not take in Float8Tensors
as inputs. This is used to standardize the logic between subclassed and non subclassed
@@ -41,25 +41,25 @@
b_inverse_scale = b_scale.reciprocal()
if output_dtype == torch.float32 and bias is not None:
# Bias is not supported by _scaled_mm when output is fp32
output, output_amax = torch._scaled_mm(
output = torch._scaled_mm(
a_data,
b_data,
out_dtype=output_dtype,
scale_a=a_inverse_scale,
scale_b=b_inverse_scale,
scale_result=output_scale,
out_dtype=output_dtype,
use_fast_accum=use_fast_accum,
)
output += bias
return output, output_amax
output, output_amax = torch._scaled_mm(
return output
output = torch._scaled_mm(
a_data,
b_data,
bias=bias,
out_dtype=output_dtype,
scale_a=a_inverse_scale,
scale_b=b_inverse_scale,
bias=bias,
scale_result=output_scale,
out_dtype=output_dtype,
use_fast_accum=use_fast_accum,
)
return output, output_amax
return output
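Note that the scales handed to `_scaled_mm` are the reciprocals of the tensor scales (via `reciprocal()` above), since the fp8 data was produced by multiplying by those scales and needs to be divided back out. A hedged usage sketch of the updated single-return `addmm_float8_unwrapped`, with hypothetical raw fp8 data and unit scales (fp8-capable GPU assumed); the argument order follows the test below:

```python
import torch
from float8_experimental.float8_python_api import addmm_float8_unwrapped

device = "cuda"  # requires a GPU with fp8 support (e.g. H100)
a_data = torch.randn(16, 32, device=device).to(torch.float8_e4m3fn)
b_data = torch.randn(32, 16, device=device).to(torch.float8_e4m3fn).t().contiguous().t()
a_scale = torch.tensor(1.0, device=device)  # hypothetical unit scales
b_scale = torch.tensor(1.0, device=device)

# Single-tensor return after this PR; previously this call unpacked (output, output_amax).
out = addmm_float8_unwrapped(
    a_data, a_scale, b_data, b_scale, output_dtype=torch.bfloat16
)
```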
12 changes: 2 additions & 10 deletions test/test_base.py
@@ -29,7 +29,6 @@
ScaledMMConfig,
)
from float8_experimental.float8_utils import (
amax_to_scale,
compute_error,
fp8_tensor_statistics,
FP8_TYPES,
@@ -327,29 +326,22 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
a_fp8 = Float8Tensor.to_float8(a, a_scale, input_dtype)
b_fp8 = Float8Tensor.to_float8(b, b_scale, input_dtype)

out_scaled_mm, output_amax_scaled = addmm_float8_unwrapped(
out_scaled_mm = addmm_float8_unwrapped(
a_fp8._data,
a_fp8._scale,
b_fp8._data,
b_fp8._scale,
output_dtype=output_dtype,
use_fast_accum=use_fast_accum,
)
out_emulated, output_amax_emulated = torch.ops.aten.mm_float8_emulated(
out_emulated = torch.ops.aten.mm_float8_emulated(
a_fp8._data, a_fp8._scale, b_fp8._data, b_fp8._scale, output_dtype
)

if output_dtype != base_dtype:
out_scaled_mm = out_scaled_mm.to(compare_type)
out_emulated = out_emulated.to(compare_type)

out_scaled_mm = out_scaled_mm / amax_to_scale(
output_amax_scaled, input_dtype
)
out_emulated = out_emulated / amax_to_scale(
output_amax_emulated, input_dtype
)

if base_dtype in {torch.bfloat16, torch.float16}:
atol, rtol = 7e-2, 7e-2
else: