Implement FPX quantisation (#19)
* Implement FPX quantisation

* Test E5M2, and check exponent_bits and mantissa_bits range

* Add comment about torch.flip
DouglasOrr committed Jul 8, 2023
1 parent 899aec4 commit beb1267
Showing 6 changed files with 601 additions and 33 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -7,7 +7,7 @@ OBJDIR ?= $(dir $(OUT))obj
CXXFLAGS = -Wall -Wextra -Werror -std=c++17 -O2 -g -fPIC -DONNX_NAMESPACE=onnx
LIBS = -lpoplar -lpopart -lpopops -lpopsparse -lpoputil -lgcl

OBJECTS = $(OBJDIR)/static_spmm.o $(OBJDIR)/autograd_proxy.o $(OBJDIR)/replicatedallreducetp.o $(OBJDIR)/distance_matrix.o $(OBJDIR)/replicatedalltoall.o
OBJECTS = $(OBJDIR)/static_spmm.o $(OBJDIR)/autograd_proxy.o $(OBJDIR)/replicatedallreducetp.o $(OBJDIR)/distance_matrix.o $(OBJDIR)/replicatedalltoall.o $(OBJDIR)/simulated_quant.o

# Rules

6 changes: 4 additions & 2 deletions dev
@@ -54,9 +54,10 @@ def build() -> None:
run(["make", "-j"])


@cli("-s", "--no-capture", action="store_false", dest="capture")
@cli("-k", "--filter")
@cli("--gdb", action="store_true")
def tests(filter: Optional[str], gdb: bool) -> None:
def tests(capture: bool, filter: Optional[str], gdb: bool) -> None:
"""run Python tests"""
build()
run(
@@ -67,6 +68,7 @@ def tests(filter: Optional[str], gdb: bool) -> None:
"tests",
None if filter else "--cov=poptorch_experimental_addons",
*(["-k", filter] if filter else []),
None if capture else "-s",
],
gdb=gdb,
)
@@ -158,7 +160,7 @@ def doc() -> None:
def ci(skip: List[str] = []) -> None:
"""run all continuous integration tests & checks"""
if "tests" not in skip:
tests(filter=None, gdb=False)
tests(capture=True, filter=None, gdb=False)
if "lint" not in skip:
lint()
if "format" not in skip:
168 changes: 165 additions & 3 deletions poptorch_experimental_addons/_impl/core.py
@@ -5,7 +5,7 @@
"""

from pathlib import Path
from typing import Any, Tuple
from typing import Any, Optional, Tuple

import poptorch
import torch
@@ -102,12 +102,174 @@ def distance_matrix(tensor1: Tensor, tensor2: Tensor, p: int) -> Tensor:
device=tensor1.device,
)
],
attributes=dict(root_path=str(Path(__file__).parent.parent)),
attributes=dict(root_path=str(Path(__file__).parent.parent.absolute())),
)
else:
y = torch.cdist(tensor1, tensor2, p=p)

return y


__all__ = ["autograd_proxy", "distance_matrix"]
def quantise_fpx(
x: Tensor,
exponent_bits: int,
mantissa_bits: int,
rounding: str = "stochastic",
fwd: bool = True,
bwd: Optional[bool] = None,
) -> Tensor:
"""
Quantise the values in a tensor to a lower-precision floating point format.
Note that this is not a cast; the returned tensor has the same dtype as the input.
quantise_fpx(tensor(0.2), exponent_bits=2, mantissa_bits=1, rounding="nearest")
=> 0.25
By default, quantise in the forward pass and return no gradient.
exponent_bits, mantissa_bits -- define the FP format (total bits = 1 (sign) + E + M)
rounding -- either "nearest" or "stochastic"
fwd -- whether to quantise the forward value
bwd -- whether to generate & whether to quantise the gradient:
bwd=None -- no gradient
bwd=False -- unquantised gradient (straight-through estimator)
bwd=True -- quantised gradient
"""
if rounding not in ["nearest", "stochastic"]:
raise ValueError(
"Expected quantise(rounding=?) to be 'nearest' or 'stochastic'"
f", actual '{rounding}'"
)

if poptorch.isRunningOnIpu():
max_exponent_bits = 5
max_mantissa_bits = 10
else:
max_exponent_bits = 8
max_mantissa_bits = 23
if exponent_bits > max_exponent_bits:
raise ValueError(
f"quantise_fpx(exponent_bits={exponent_bits}) not supported, maximum"
f" number of exponent bits is {max_exponent_bits}"
)
if mantissa_bits > max_mantissa_bits:
raise ValueError(
f"quantise_fpx(mantissa_bits={mantissa_bits}) not supported, maximum"
f" number of mantissa bits is {max_mantissa_bits}"
)

q: Tensor
if poptorch.isRunningOnIpu():
(q,) = poptorch.custom_op(
name="SimulatedQuant",
domain_version=1,
domain="ai.graphcore",
inputs=[x],
example_outputs=[x],
attributes=dict(
root_path=str(Path(__file__).parent.parent.absolute()),
exponent_bits=exponent_bits,
mantissa_bits=mantissa_bits,
rounding=rounding,
fwd=fwd,
bwd={True: "quantise", False: "ste", None: "stop"}[bwd],
),
)
return q

def _quantise(x: Tensor) -> Tensor:
max_exponent = 2 ** (exponent_bits - 1) - 1
absmax = 2**max_exponent * (2 - 2**-mantissa_bits)
downscale = 2.0 ** (126 - max_exponent)
mask = torch.tensor(
2 ** (23 - mantissa_bits) - 1, dtype=torch.int32, device=x.device
)
offset = (
torch.randint( # type:ignore[call-overload]
0, mask + 1, x.shape, dtype=torch.int32, device=x.device
)
if rounding == "stochastic"
else mask // 2
)
# Manually clip to the format's maximum representable magnitude,
# then scale down (to generate appropriate subnormals), add the rounding
# offset and mask off the low (23 - mantissa_bits) mantissa bits
q = x.to(torch.float32)
q = torch.clip(q, -absmax, absmax)
q /= downscale
q = ((q.view(torch.int32) + offset) & ~mask).view(torch.float32)
q *= downscale
q = q.to(x.dtype)
return q

class F(torch.autograd.Function):
@staticmethod
def forward( # type:ignore[override]
ctx: torch.autograd.function.FunctionCtx, xx: Tensor
) -> Tensor:
return _quantise(xx) if fwd else xx.clone()

@staticmethod
def backward( # type:ignore[override]
ctx: torch.autograd.function.FunctionCtx, grad_y: Tensor
) -> Optional[Tensor]:
if bwd is not None:
return _quantise(grad_y) if bwd else grad_y
return None

q = F.apply(x) # type:ignore[no-untyped-call]
return q
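
As an illustrative aside (not part of the diff): a minimal sketch of the CPU fallback path above, assuming quantise_fpx is re-exported at the package root as listed in __all__; the only asserted value is the 0.25 result from the docstring example.

import torch
import poptorch_experimental_addons as pea  # assumed package-root re-export

# Reproduce the docstring example: round-to-nearest in a 1-2-1 (sign/exponent/mantissa) format
x = torch.tensor(0.2)
y = pea.quantise_fpx(x, exponent_bits=2, mantissa_bits=1, rounding="nearest")
print(y)  # tensor(0.2500) -- same dtype as the input, not a cast

# bwd=False gives a straight-through estimator: quantised forward, unchanged gradient
x = torch.tensor([0.2, 0.7], requires_grad=True)
y = pea.quantise_fpx(x, exponent_bits=2, mantissa_bits=1, rounding="nearest", bwd=False)
y.sum().backward()
print(x.grad)  # tensor([1., 1.])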


def quantise_fpx_ste(
x: Tensor,
exponent_bits: int,
mantissa_bits: int,
rounding: str = "stochastic",
) -> Tensor:
"""
Quantise the forward value while leaving the gradient unchanged, as a
straight-through estimator.
See `quantise_fpx` for more detail.
"""
return quantise_fpx(
x,
exponent_bits=exponent_bits,
mantissa_bits=mantissa_bits,
rounding=rounding,
fwd=True,
bwd=False,
)


def quantise_fpx_grad(
x: Tensor,
exponent_bits: int,
mantissa_bits: int,
rounding: str = "stochastic",
) -> Tensor:
"""
Quantise the gradient while leaving the forward value unchanged.
See `quantise_fpx` for more detail.
"""
return quantise_fpx(
x,
exponent_bits=exponent_bits,
mantissa_bits=mantissa_bits,
rounding=rounding,
fwd=False,
bwd=True,
)


__all__ = [
"autograd_proxy",
"distance_matrix",
"quantise_fpx",
"quantise_fpx_ste",
"quantise_fpx_grad",
]
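
As a further hypothetical sketch (not part of the diff), the two wrappers could be combined to simulate E5M2 arithmetic around a matmul; the shapes, import alias and surrounding training context are illustrative assumptions.

import torch
import poptorch_experimental_addons as pea  # assumed package-root re-export

weight = torch.randn(16, 32, requires_grad=True)
x = torch.randn(4, 32)

# Simulate E5M2 weights: quantised forward value, straight-through gradient
w_q = pea.quantise_fpx_ste(weight, exponent_bits=5, mantissa_bits=2)
y = x @ w_q.T
# Simulate E5M2 quantisation of the gradient flowing back through the activation
y = pea.quantise_fpx_grad(y, exponent_bits=5, mantissa_bits=2)

y.sum().backward()
print(weight.grad.shape)  # torch.Size([16, 32]) -- gradient reaches the full-precision weight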