Changes on top of upstream to get rid of type errors #248
Conversation
@drisspg :)
Awesome! I will take a look at this tomorrow.
float8_experimental/float8_utils.py
Outdated
@@ -28,12 +28,25 @@
 IS_AMD = torch.cuda.is_available() and torch.version.hip is not None

+# Helper functions to get individual F8 types based on backend architecture
Would it be possible to put this into configuration instead of setting it dynamically? It can be unexpected for numerics to change based on the environment. It would also be good to support numerical emulation of all of these types regardless of whether the user's machine supports a float8 matmul.
I'm afraid I don't understand your question. These helper functions are simply intended to grab the "right" version of the prebuilt torch F8 types. Could you elaborate on the change you'd like to see?
Sure, it's just about encoding the dtype flavors in configuration instead of making them environment dependent. Having this in configuration would make it easier to debug numerics without having access to the target hardware.
# float8 dtypes have a default which can be changed explicitly
config = ...
config.float8_flavors = 'nuz'
do_float8_things(..., config)
versus
# float8 dtypes magically change based on the environment
do_float8_things(...)
That said, my comment is not high pri, feel free to land and we can adjust this later if it becomes important.
@vkuzo Sorry, got wrapped up in other work recently and just circled back to this. Okay, I will add an option in float8_experimental/float8_experimental/config.py, and instead of checking the backend architecture, the code will check against this user-settable config variable.
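As a rough sketch of that direction (assuming the fp8_e4m3_t/fp8_e5m2_t helper names and the use_fnuz_dtype flag that appear in the diffs below; the actual implementation in the PR may differ):

import torch

import float8_experimental.config as config


def fp8_e4m3_t() -> torch.dtype:
    # Select the e4m3 flavor from the user-settable config flag instead of
    # probing the backend architecture.
    return torch.float8_e4m3fnuz if config.use_fnuz_dtype else torch.float8_e4m3fn


def fp8_e5m2_t() -> torch.dtype:
    # Same idea for the e5m2 flavor.
    return torch.float8_e5m2fnuz if config.use_fnuz_dtype else torch.float8_e5m2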
test/test_base.py
Outdated
@@ -350,7 +374,7 @@ def test_scaled_mm_vs_emulated(self, base_dtype):

 class TestNumerics:
-    @pytest.mark.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+    @pytest.mark.parametrize("float8_dtype", [fp8_e4m3_t(), fp8_e5m2_t()])
I would recommend testing all cases on all hardware types instead. For things not requiring a matmul, it should just work. For things requiring a matmul, we have an emulation mode to at least help approximate it.
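For illustration, a sketch of parametrizing over every float8 flavor regardless of the local hardware; the dtype list and test body are assumptions for this example, not code from the PR:

import pytest
import torch

# All four float8 flavors currently exposed by PyTorch.
ALL_FLOAT8_DTYPES = [
    torch.float8_e4m3fn,
    torch.float8_e5m2,
    torch.float8_e4m3fnuz,
    torch.float8_e5m2fnuz,
]


@pytest.mark.parametrize("float8_dtype", ALL_FLOAT8_DTYPES)
def test_cast_round_trip(float8_dtype):
    # Plain casts do not need a float8 matmul, so they should run on any
    # hardware; matmul-based tests would enable the emulation mode instead.
    x = torch.randn(16, 16)
    x_fp8 = x.to(float8_dtype)
    assert x_fp8.dtype == float8_dtype
    assert x_fp8.float().dtype == torch.float32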
test/test_base.py
Outdated
@@ -47,7 +51,10 @@ class TestFloat8Tensor(unittest.TestCase):
     def test_preserves_dtype(self) -> None:
         # hp means high precision, lp means low precision
         hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)
-        lp_dtypes = (torch.float8_e4m3fn, torch.float8_e5m2)
+        fp8_dtypes = (
+            FP8Dtypes()
All dtypes would be nice; there should not be anything in Float8Tensor which is hardware dependent.
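A minimal sketch of what covering every flavor in a hardware-independent way could look like (the hp/lp naming follows the snippet above; the loop body is an illustrative placeholder rather than the real test):

import itertools

import torch

hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)
lp_dtypes = (
    torch.float8_e4m3fn,
    torch.float8_e5m2,
    torch.float8_e4m3fnuz,
    torch.float8_e5m2fnuz,
)

for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes):
    x = torch.randn(4, 4, dtype=hp_dtype)
    # Round-trip through the low-precision dtype; the result should come
    # back in the original high-precision dtype.
    x_round_trip = x.to(lp_dtype).to(hp_dtype)
    assert x_round_trip.dtype == hp_dtype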
test/test_base.py
Outdated
         m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
-        m = get_float8_linear(linear_type, m, emulate, False)
+        m = get_float8_linear(linear_type, m, emulate, False, fp8_dtypes)
you can enable emulation here if your hardware doesn't support the dtype under test
@vkuzo Ready for another review. Also, I wanted to ask if there was an ETA or roadmap for when this functionality would be pulled into PyTorch proper?
So sorry, I am on holiday right now; I will take a look late next week when I return, unless @drisspg wants to get to it sooner.
Yeah, will review tomorrow.
What was the output of test_everything.sh?
Fails in test_compile. However, I was aware of this failure and found that it is unrelated to my changes; rather, it is an issue with torch.compile on ROCm. This failure was next on my TODOs to address. I will upload a follow-up PR.
float8_experimental/config.py
Outdated
+# If True, use 'fnuz' float8 types for calculations. If the backend
+# hardware does not support a particular type, the emulated implementation
+# of the dtype will be used. Currently, ROCm only supports the fnuz variants.
+use_fnuz_dtype = True
default to False?
Ah, yes, that's an oversight; it was just easier for me to test that way. Will change.
test/test_dtensor.py
Outdated
@@ -128,7 +128,7 @@ def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
     )

     out = torch.nn.functional.linear(dist_x_fp8, dist_weight_fp8)
-    out = NoopFwToFloat8E5M2Bw.apply(out, False)
+    out = NoopFwToFloat8E5M2Bw.apply(out, False, fp8_e5m2_t())
is the new last arg here expected?
Oversight, fixed in latest commit
(Branch force-pushed from e29cc35 to ba9b5dd.)
(Branch force-pushed from 5da5b5c to 4f304cc.)
Just an update here: the base PR should have landed yesterday.
(Branch force-pushed from 57120fa to 4997c19.)
float8_experimental/config.py
Outdated
@@ -19,3 +19,8 @@
 # implements pre/post-all-gather methods to do fp8 all-gather with FSDP2.
 # Only dynamic scaling is supported for now.
 enable_fsdp_fp8_all_gather = False

+# If True, use 'fnuz' float8 types for calculations. If the backend
+# hardware does not support a particular dtype, the emulated implementation
nit: currently the user is responsible for toggling the emulation setting; we don't do that automatically.
I thought, per previous comments, that we wanted to go the emulated route if the backend hardware didn't support the type? In general the user could force emulation, but I was under the impression that when a dtype is used that isn't supported on the underlying hardware, we wanted to fall back to the emulated route?
Yep, that is correct! This is just currently done by the user explicitly, and there is no support to handle that automatically.
oh, you're saying the comment is wrong
fixed!
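For reference, the corrected comment in float8_experimental/config.py might read roughly like this (a sketch of the intent discussed above, not the exact wording that landed):

# If True, use 'fnuz' float8 dtypes for calculations. Currently, ROCm only
# supports the fnuz variants. Note that emulation is not selected
# automatically: on hardware without native support for the chosen dtypes,
# the user is responsible for enabling the emulate option explicitly.
use_fnuz_dtype = False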
looks great! thanks for helping!
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@drisspg Fixed lint, but not sure what the ufmt errors are about.
@alugorey I think if you apply this patch it should work: https://gist.github.com/drisspg/2a87d54521a0b2312ac44d070f63350d
@drisspg Looks like you beat me to it. Still failing on 2 files though. Is there documentation on what ufmt expects? Some of those changes seem purely cosmetic.
The reason for this formatting is some internal notes on code styling. TBH I also have some 'pre-commit' hooks for the repo that should make this set-and-forget. There do seem to be 2 more lint fixes needed. There seems to be 1 real test failure:
I found that this patch passed test_everything: https://gist.github.com/drisspg/47a29d6bf3fcca2a2c48d09b74c564aa
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Fixes the class of failed unit tests on ROCm in test_base.py that fail with the internal assertion:
Cannot convert ScalarType Float8_e4m3fn to hipDataType.
Note: we are aware of the outstanding numerical issues and are looking into them internally.