Skip to content
This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 03c9119

Browse files
committed
Update on "[3/x] clean up casting functions: delete to_fp8_no_autograd"
Summary: `ToFloat8ConstrFunc` was just calling `to_fp8_no_autograd`, unify them to reduce confusion. We can rename the function in a future PR, keeping PRs small for now. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D60292694](https://our.internmc.facebook.com/intern/diff/D60292694) [ghstack-poisoned]
1 parent e130144 commit 03c9119

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

float8_experimental/float8_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def forward(
169169

170170
if isinstance(bits_fp8, DTensor):
171171
assert isinstance(
172-
x, DTensor
172+
scale, DTensor
173173
), "Expected Float8 scale to be a DTensor if bits_fp8 is a DTensor"
174174
bits_mesh = bits_fp8.device_mesh
175175
bits_placements = bits_fp8.placements

0 commit comments

Comments
 (0)