Adds float8 support back to nanoGPT #446

Draft: wants to merge 1 commit into main
6 changes: 6 additions & 0 deletions tripy/examples/nanogpt/README.md
@@ -80,3 +80,9 @@ To run with a quantization mode, pass `--quant-mode` to `example.py`. The suppor
```
Tripy: TEST: EXPECTED_STDOUT End
-->

3. float8 quantization:

```bash
python3 example.py --input-text "What is the answer to life, the universe, and everything?" --seed=0 --quant-mode float8
```
2 changes: 1 addition & 1 deletion tripy/examples/nanogpt/example.py
@@ -72,7 +72,7 @@ def main():
"--quant-mode",
type=str,
help="Quantization mode.",
choices=["int8-weight-only", "int4-weight-only"],
choices=["int8-weight-only", "int4-weight-only", "float8"],
)

args = parser.parse_args()
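The `example.py` change only registers the new mode with argparse; everything else is driven by the model and quantization modules below. A self-contained sketch of the resulting CLI behavior (only the argument definition is taken from this diff):

```python
import argparse

# Mirror of the --quant-mode argument after this change.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--quant-mode",
    type=str,
    help="Quantization mode.",
    choices=["int8-weight-only", "int4-weight-only", "float8"],
)

args = parser.parse_args(["--quant-mode", "float8"])
print(args.quant_mode)  # -> "float8"
# Passing an unregistered value such as "fp8" makes argparse exit with "invalid choice".
```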
13 changes: 10 additions & 3 deletions tripy/examples/nanogpt/model.py
@@ -43,6 +43,9 @@ def linear_layer(config: GPTConfig, in_feat, out_feat, bias):
elif config.quant_mode == "int4-weight-only":
quant_kwargs["quant_dtype"] = tp.int4
quant_kwargs["weight_quant_dim"] = None
elif config.quant_mode == "float8":
quant_kwargs["quant_dtype"] = tp.float8
quant_kwargs["weight_quant_dim"] = None

return tp.Linear(
in_feat,
@@ -73,7 +76,7 @@ def __call__(self, x: tp.Tensor):
qkv = self.c_attn(x) # (batch_size, seq_len, 3 * embedding_size)

# WAR for better accuracy and avoid TRT compilation error in fp16
if self.c_attn.quant_dtype == tp.int4:
if self.c_attn.quant_dtype in (tp.float8, tp.int4):
qkv = tp.cast(qkv, tp.float32)

q, k, v = tp.split(qkv, 3, dim=2)
@@ -156,8 +159,12 @@ def __init__(self, config):
), f"Cannot forward sequence of length {config.seq_len}, block size is only {config.block_size}"

self.transformer = Transformer(config)
# Quantization is disabled for `lm_head`
self.lm_head = tp.Linear(config.embedding_size, config.vocab_size, bias=False, dtype=config.dtype)

if config.quant_mode == "float8":
self.lm_head = linear_layer(config, config.embedding_size, config.vocab_size, bias=False)
else:
# Quantization is disabled for `lm_head` except for FP8.
self.lm_head = tp.Linear(config.embedding_size, config.vocab_size, bias=False, dtype=config.dtype)

def __call__(self, idx):
x = self.transformer(idx)
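In the `linear_layer` change above, the float8 branch sets the same quantization kwargs the int4 branch uses: `quant_dtype` selects the storage type and `weight_quant_dim=None` requests a single per-tensor scale rather than one scale per output channel. A minimal sketch of the layer that branch would produce, assuming `import tripy as tp` and that `tp.Linear` accepts these keyword arguments as the surrounding code implies:

```python
import tripy as tp  # import name assumed

# Sketch only: a float8 weight-quantized projection as linear_layer builds it
# when config.quant_mode == "float8". The sizes and dtype are illustrative.
qkv_proj = tp.Linear(
    768,        # in_feat, e.g. embedding_size
    3 * 768,    # out_feat, e.g. the fused QKV projection
    bias=True,
    dtype=tp.float16,
    quant_dtype=tp.float8,   # storage type for the quantized weights
    weight_quant_dim=None,   # None => one per-tensor scale
)
```

The second hunk widens the existing fp16 workaround so a float8-quantized `c_attn` also computes `qkv` in float32, and the last hunk quantizes `lm_head` only in float8 mode.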
2 changes: 2 additions & 0 deletions tripy/examples/nanogpt/quantization.py
@@ -40,6 +40,8 @@ def modelopt_quantize(model_hf, quant_mode):
}
elif quant_mode == "int4-weight-only":
quant_cfg = mtq.INT4_AWQ_CFG
elif quant_mode == "float8":
quant_cfg = mtq.FP8_DEFAULT_CFG
else:
raise NotImplementedError(f"Unsupported quantization mode: {quant_mode}")

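`FP8_DEFAULT_CFG` is ModelOpt's stock float8 recipe (per-tensor FP8 quantizers on weights and activations). As a rough sketch of how the selected config is typically applied, with the calibration loop as an assumed placeholder that is not part of this diff:

```python
import modelopt.torch.quantization as mtq

def calibration_forward_loop(model):
    # Placeholder: run a handful of representative batches so ModelOpt can
    # collect the amax statistics used to derive the float8 scales.
    for batch in calib_batches:  # calib_batches is assumed, not defined in this diff
        model(batch)

# model_hf is the Hugging Face model passed into modelopt_quantize above.
model_hf = mtq.quantize(model_hf, mtq.FP8_DEFAULT_CFG, calibration_forward_loop)
```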
2 changes: 1 addition & 1 deletion tripy/tests/conftest.py
@@ -29,7 +29,7 @@
from tripy.common.datatype import DATA_TYPES

skip_if_older_than_sm89 = pytest.mark.skipif(
torch.cuda.get_device_capability() < (8, 9), reason="Some features (e.g. fp8) are not available before SM90"
torch.cuda.get_device_capability() < (8, 9), reason="Some features (e.g. float8) are not available before SM89"
)

skip_if_older_than_sm80 = pytest.mark.skipif(
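Only the wording of the skip reason changes here; the marker itself is unchanged and is applied as an ordinary pytest decorator. A hypothetical example (the marker is presumably imported from this conftest by the test modules):

```python
@skip_if_older_than_sm89
def test_some_float8_feature():
    # Hypothetical: runs only on GPUs with compute capability 8.9 (Ada) or newer,
    # which is where hardware float8 support begins.
    ...
```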
6 changes: 3 additions & 3 deletions tripy/tests/integration/test_dequantize.py
@@ -63,12 +63,12 @@ def func(input):
output = torch.from_dlpack(dequantized)
assert torch.allclose(expected, output.to("cpu"))

# TODO(#161): Update fp8 test to use frontend representation
# TODO(#161): Update float8 test to use frontend representation
@pytest.mark.parametrize(
"dtype", [tp.float32, tp.float16, pytest.param(tp.bfloat16, marks=skip_if_older_than_sm80)]
)
@skip_if_older_than_sm89
def test_dequantize_fp8_per_tensor(self, dtype):
def test_dequantize_float8_per_tensor(self, dtype):
data_value = [1.0, 1.0]
input_tp = tp.Tensor(data_value, dtype=tp.float8)
scale = torch.tensor(0.5, dtype=TORCH_DTYPES[dtype])
@@ -84,7 +84,7 @@ def test_dequantize_fp8_per_tensor(self, dtype):
"dtype", [tp.float32, tp.float16, pytest.param(tp.bfloat16, marks=skip_if_older_than_sm80)]
)
@skip_if_older_than_sm89
def test_dequantize_fp8_per_channel(self, dtype):
def test_dequantize_float8_per_channel(self, dtype):
data_value = [[1.0, 1.0], [1.0, 1.0]]
input_tp = tp.Tensor(data_value, dtype=tp.float8)
scale = torch.tensor([0.8, 0.9], dtype=TORCH_DTYPES[dtype])
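The changes here are purely a rename from `fp8` to `float8` in the TODO and test names. For readers unfamiliar with the test shape, a condensed sketch of the per-tensor path these tests exercise; the exact `tp.dequantize` signature and the scale conversion are assumed from the surrounding test code:

```python
import torch
import tripy as tp  # import name assumed

# float8 values [1.0, 1.0] with a per-tensor scale of 0.5 dequantize to [0.5, 0.5].
input_tp = tp.Tensor([1.0, 1.0], dtype=tp.float8)
scale_tp = tp.Tensor(torch.tensor(0.5, dtype=torch.float16))
dequantized = tp.dequantize(input_tp, scale_tp, tp.float16)  # signature assumed

expected = torch.tensor([0.5, 0.5], dtype=torch.float16)
assert torch.allclose(expected, torch.from_dlpack(dequantized).to("cpu"))
```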
4 changes: 2 additions & 2 deletions tripy/tests/integration/test_quantize.py
@@ -73,7 +73,7 @@ def func(input):
"dtype", [tp.float32, tp.float16, pytest.param(tp.bfloat16, marks=skip_if_older_than_sm80)]
)
@skip_if_older_than_sm89
def test_quantize_fp8_per_tensor(self, dtype, eager_or_compiled):
def test_quantize_float8_per_tensor(self, dtype, eager_or_compiled):
input = torch.tensor([1.0, 2.0], dtype=TORCH_DTYPES[dtype])
scale = torch.tensor(0.5, dtype=TORCH_DTYPES[dtype])
input_tp = tp.Tensor(input, dtype=dtype)
@@ -96,7 +96,7 @@ def func(input):
"dtype", [tp.float32, tp.float16, pytest.param(tp.bfloat16, marks=skip_if_older_than_sm80)]
)
@skip_if_older_than_sm89
def test_quantize_fp8_per_channel(self, dtype, eager_or_compiled):
def test_quantize_float8_per_channel(self, dtype, eager_or_compiled):
input = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=TORCH_DTYPES[dtype])
scale = torch.tensor([0.2, 0.1], dtype=TORCH_DTYPES[dtype])
input_tp = tp.Tensor(input, dtype=dtype)
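Same rename on the quantize side. A mirror-image sketch of the per-tensor round trip (again, the `tp.quantize`/`tp.dequantize` signatures are assumed from how the tests use them):

```python
import torch
import tripy as tp  # import name assumed

# Quantizing [1.0, 2.0] with a per-tensor scale of 0.5 stores [2.0, 4.0] in float8;
# dequantizing multiplies the scale back in, recovering the original values exactly.
input_tp = tp.Tensor(torch.tensor([1.0, 2.0], dtype=torch.float16), dtype=tp.float16)
scale_tp = tp.Tensor(torch.tensor(0.5, dtype=torch.float16))

quantized = tp.quantize(input_tp, scale_tp, tp.float8)      # signature assumed
roundtrip = tp.dequantize(quantized, scale_tp, tp.float16)

expected = torch.tensor([1.0, 2.0], dtype=torch.float16)
assert torch.allclose(expected, torch.from_dlpack(roundtrip).to("cpu"))
```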