Skip to content

Commit

Permalink
Merge pull request #56 from graphcore-research/stochastic-rounding-num-bits
Browse files Browse the repository at this point in the history

Parametrize the number of random bits used in stochastic rounding
  • Loading branch information
awf committed Jun 27, 2024
2 parents bd72009 + 3457e84 commit a5b595c
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 9 deletions.
19 changes: 17 additions & 2 deletions unit_scaling/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,15 @@ class FPFormat:
exponent_bits: int  # Width of the exponent field in bits (must be >= 2, checked in __post_init__)
mantissa_bits: int  # Width of the mantissa (fraction) field in bits
rounding: str = "stochastic"  # Rounding mode: "stochastic" or "nearest"
srbits: int = 0  # Number of bits for stochastic rounding, zero => use all bits

def __post_init__(self) -> None:
assert self.exponent_bits >= 2, "FPFormat requires at least 2 exponent bits"
assert (
self.srbits == 0 or self.rounding == "stochastic"
), "Nonzero srbits for non-stochastic rounding"
if self.srbits == 0 and self.rounding == "stochastic":
self.srbits = 23 - self.mantissa_bits

@property
def bits(self) -> int:
Expand Down Expand Up @@ -74,9 +80,18 @@ def quantise(self, x: Tensor) -> Tensor:
downscale = 2.0 ** (127 - 2 ** (self.exponent_bits - 1))
mask = torch.tensor(2 ** (23 - self.mantissa_bits) - 1, device=x.device)
if self.rounding == "stochastic":
offset = torch.randint( # type: ignore[call-overload]
0, mask + 1, x.shape, dtype=torch.int32, device=x.device
srbitsbar = 23 - self.mantissa_bits - self.srbits
offset = (
torch.randint(
0, 2**self.srbits, x.shape, dtype=torch.int32, device=x.device
)
<< srbitsbar
)
# Correct for bias. We can do this only for srbits < 23-mantissa_bits,
# but it is only likely to matter when srbits is small.
if srbitsbar > 0:
offset += 1 << (srbitsbar - 1)

elif self.rounding == "nearest":
offset = mask // 2
else: # pragma: no cover
Expand Down
16 changes: 9 additions & 7 deletions unit_scaling/tests/test_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@ def test_fp_format_rounding() -> None:
y_nearest = FPFormat(2, 1, rounding="nearest").quantise(torch.full((n,), x))
assert collections.Counter(y_nearest.tolist()) == {-1.5: n}

y_stochastic = FPFormat(2, 1, rounding="stochastic").quantise(torch.full((n,), x))
count = collections.Counter(y_stochastic.tolist())
assert count.keys() == {-1.5, -1.0}
expected_ratio = (1.35 - 1.0) / 0.5
nearest_ratio = count[-1.5] / sum(count.values())
std_x3 = 3 * (expected_ratio * (1 - expected_ratio) / n) ** 0.5
assert expected_ratio - std_x3 < nearest_ratio < expected_ratio + std_x3
for srbits in (0, 13):
srformat = FPFormat(2, 1, rounding="stochastic", srbits=srbits)
y_stochastic = srformat.quantise(torch.full((n,), x))
count = collections.Counter(y_stochastic.tolist())
assert count.keys() == {-1.5, -1.0}
expected_ratio = (1.35 - 1.0) / 0.5
nearest_ratio = count[-1.5] / sum(count.values())
std_x3 = 3 * (expected_ratio * (1 - expected_ratio) / n) ** 0.5
assert expected_ratio - std_x3 < nearest_ratio < expected_ratio + std_x3


def test_fp_format_bwd() -> None:
Expand Down

0 comments on commit a5b595c

Please sign in to comment.