
Implement FPX quantisation #19

Merged: 3 commits into main from quantise-fpx, Jul 8, 2023
Conversation

@DouglasOrr (Contributor) commented on Jul 7, 2023

Shouldn't be anything too surprising (as mentioned, I think the low end of E5M2 is a bit wrong on IPU, but probably close enough for our purposes & irrelevant if FP8/16 are used together, in any case).

Unfortunately the IPU code doesn't work on IPUModel, so the (CPU) GitHub CI tests are a bit limited.

TODO: I need to put some range assertions in the IPU code for exponent & mantissa.
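
For context, E5M2 (5 exponent bits, 2 mantissa bits, bias 15) tops out at 1.75 * 2**15 = 57344 and switches to a subnormal grid with step 2**-16 below 2**-14, which is the "low end" behaviour mentioned above. A rough CPU-side sketch of saturating round-to-nearest E5M2 quantisation, purely for illustration (quantise_e5m2_nearest is a hypothetical helper, not the IPU codelet in this PR):

import torch


def quantise_e5m2_nearest(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical reference, assuming saturation at the max finite value
    # rather than overflow to inf.
    max_value = 57344.0                 # largest finite E5M2 value: 1.75 * 2**15
    min_normal_exp = -14                # exponent bias 15, so smallest normal is 2**-14
    mantissa_bits = 2
    x = x.clamp(-max_value, max_value)  # saturate out-of-range values
    exponent = torch.floor(torch.log2(x.abs().clamp_min(2.0**min_normal_exp)))
    step = 2.0 ** (exponent - mantissa_bits)  # spacing of representable values near x
    return torch.round(x / step) * step       # round-to-nearest (ties to even)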

@thecharlieblake (Collaborator) left a comment:
All looks good to me! 🚢 Happy with the design choices. Bring on IPU unit scaling!

tests/test_core.py (resolved conversation)
}

x = torch.linspace(-1e5, 1e5, steps=1000)
grad_y = torch.flip(x, (0,))

@thecharlieblake (Collaborator):
Does flip do anything in particular here, or is it just a convenient way to have grads that differ from the inputs?

@DouglasOrr (Contributor, author):

Just a convenient way to have different grads. I'll pop a note in.
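
For illustration, the note could be as small as a comment on the quoted line (my wording, not the actual commit):

x = torch.linspace(-1e5, 1e5, steps=1000)
grad_y = torch.flip(x, (0,))  # flip has no special meaning, it just makes grad_y differ from x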


def assert_quantised(t: Tensor) -> None:
    assert len(set(t.tolist())) <= 256
    torch.testing.assert_close(t.max(), torch.tensor(57344.0))

@thecharlieblake (Collaborator):

I think the rtol here for fp32 will be 1.3e-6, which doesn't seem enough if you assume you might occasionally get less than the max value (maybe you won't though...)

A quick test:

>>> x = torch.tensor(57344.0)
>>> y = torch.tensor(49152.)
>>> assert_close(x, y)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/charlieb/Projects/graphcore/unit-scaling/.venv/lib/python3.10/site-packages/torch/testing/_comparison.py", line 1511, in assert_close
    raise error_metas[0].to_error(msg)
AssertionError: Scalars are not close!

Absolute difference: 8192.0 (up to 1e-05 allowed)
Relative difference: 0.16666666666666666 (up to 1.3e-06 allowed)

I'd be tempted to set an atol of maybe 2**14 ish?
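
If that route were taken, note that assert_close requires rtol and atol to be supplied together. A sketch of the suggested looser check (not what the PR actually does):

# hypothetical: accept anything within 2**14 of the E5M2 max
torch.testing.assert_close(t.max(), torch.tensor(57344.0), rtol=0.0, atol=2.0**14)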

@DouglasOrr (Contributor, author):
I think I should always get the max, since x = torch.linspace(-1e5, 1e5, steps=1000) in this test case. I'm using nearest rounding here, but I think the same should be true if it were stochastic(?)
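
The endpoints of that linspace are exactly ±1e5, which is above the largest finite E5M2 value, so assuming the quantiser saturates at the maximum rather than overflowing, the max comes out as exactly 57344:

>>> x = torch.linspace(-1e5, 1e5, steps=1000)
>>> x.max()
tensor(100000.)
>>> # 100000 > 57344 = 1.75 * 2**15, so a saturating round-to-nearest maps it onto the E5M2 max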

poptorch_experimental_addons/_impl/core.py (two resolved conversations)
def forward(  # type:ignore[override]
    ctx: torch.autograd.function.FunctionCtx, xx: Tensor
) -> Tensor:
    return _quantise(xx) if fwd else xx.clone()

@thecharlieblake (Collaborator):

I feel like I've come across this before, but is there a reason you need to clone in fwd but not bwd?

@DouglasOrr (Contributor, author):

Ah, it's a bit of a mystery: I recall getting quite a complex error message when I didn't clone() here. Perhaps double-differentiation would require a clone() in bwd too(?)
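
For readers following along, here is a minimal self-contained sketch of the pattern under discussion, assuming nothing beyond plain PyTorch: the forward pass either quantises or clones its input, and the backward pass either quantises the incoming gradient or passes it straight through. quantise_op and its quantise argument are hypothetical names, not the library's API.

import torch
from torch import Tensor
from typing import Callable


def quantise_op(
    quantise: Callable[[Tensor], Tensor], fwd: bool, bwd: bool
) -> Callable[[Tensor], Tensor]:
    # Hypothetical factory mirroring the shape of the code under review:
    # quantise the activation (fwd), the gradient (bwd), or both.
    class _Quantise(torch.autograd.Function):
        @staticmethod
        def forward(  # type:ignore[override]
            ctx: torch.autograd.function.FunctionCtx, xx: Tensor
        ) -> Tensor:
            # clone() so the output is a fresh tensor rather than the Function's own input
            return quantise(xx) if fwd else xx.clone()

        @staticmethod
        def backward(  # type:ignore[override]
            ctx: torch.autograd.function.FunctionCtx, grad_y: Tensor
        ) -> Tensor:
            return quantise(grad_y) if bwd else grad_y

    return _Quantise.apply

Usage, combined with a quantiser such as the E5M2 sketch above: y = quantise_op(quantise_e5m2_nearest, fwd=True, bwd=False)(x).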

@DouglasOrr merged commit beb1267 into main on Jul 8, 2023 (1 check passed)
@DouglasOrr deleted the quantise-fpx branch on Jul 8, 2023 at 19:16