Skip to content

Commit

Permalink
Add basic FP8 KV cache support (#2603)
Browse files Browse the repository at this point in the history
* Add basic FP8 KV cache support

This change adds rudimentary FP8 KV cache support. The support is
enabled by passing `--kv-cache-dtype fp8_e5m2` to the launcher. Doing so
uses this type for the KV cache. However, support is still limited:

* Only the `fp8_e5m2` type is supported.
* The KV cache layout is the same as `float16`/`bfloat16` (HND).
* The FP8 KV cache is only supported for FlashInfer.
* Loading of scales is not yet supported.

* Fix Cargo.toml
  • Loading branch information
danieldk authored Oct 4, 2024
1 parent 6810307 commit 2358c2b
Show file tree
Hide file tree
Showing 33 changed files with 1,015 additions and 192 deletions.
14 changes: 7 additions & 7 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions docs/source/reference/launcher.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,15 @@ Options:
[env: DTYPE=]
[possible values: float16, bfloat16]

```
## KV_CACHE_DTYPE
```shell
--kv-cache-dtype <KV_CACHE_DTYPE>
Specify the dtype for the key-value cache. When this option is not provided, the dtype of the model is used (typically `float16` or `bfloat16`). Currently the only supported value is `fp8_e5m2` on CUDA.

[env: KV_CACHE_DTYPE=]
[possible values: fp8_e5m2]

```
## TRUST_REMOTE_CODE
```shell
Expand Down
9 changes: 8 additions & 1 deletion integration-tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ def local_launcher(
use_flash_attention: bool = True,
disable_grammar_support: bool = False,
dtype: Optional[str] = None,
kv_cache_dtype: Optional[str] = None,
revision: Optional[str] = None,
max_input_length: Optional[int] = None,
max_batch_prefill_tokens: Optional[int] = None,
Expand Down Expand Up @@ -375,6 +376,9 @@ def local_launcher(
if dtype is not None:
args.append("--dtype")
args.append(dtype)
if kv_cache_dtype is not None:
args.append("--kv-cache-dtype")
args.append(kv_cache_dtype)
if revision is not None:
args.append("--revision")
args.append(revision)
Expand Down Expand Up @@ -434,6 +438,7 @@ def docker_launcher(
use_flash_attention: bool = True,
disable_grammar_support: bool = False,
dtype: Optional[str] = None,
kv_cache_dtype: Optional[str] = None,
revision: Optional[str] = None,
max_input_length: Optional[int] = None,
max_batch_prefill_tokens: Optional[int] = None,
Expand All @@ -456,6 +461,9 @@ def docker_launcher(
if dtype is not None:
args.append("--dtype")
args.append(dtype)
if kv_cache_dtype is not None:
args.append("--kv-cache-dtype")
args.append(kv_cache_dtype)
if revision is not None:
args.append("--revision")
args.append(revision)
Expand Down Expand Up @@ -589,7 +597,6 @@ async def generate_load_inner(
max_new_tokens: int,
seed: Optional[int] = None,
) -> List[Response]:

import numpy as np

arange = np.arange(len(prompts))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 3923,
"logprob": -5.6328125,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.079956055,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.2763672,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37548828,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4628906,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02885437,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.2565918,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0063438416,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3056641,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.6035156,
"special": false,
"text": " is"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 3,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 374,
"logprob": -22.96875,
"text": " is"
},
{
"id": 5655,
"logprob": -10.71875,
"text": " deep"
},
{
"id": 6975,
"logprob": -2.6992188,
"text": " learning"
},
{
"id": 30,
"logprob": -4.8398438,
"text": "?"
}
],
"seed": 0,
"tokens": [
{
"id": 720,
"logprob": -0.4411621,
"special": false,
"text": " \n"
},
{
"id": 220,
"logprob": -0.35864258,
"special": false,
"text": " "
},
{
"id": 128001,
"logprob": 0.0,
"special": true,
"text": "<|end_of_text|>"
}
],
"top_tokens": null
},
"generated_text": "What is deep learning? \n "
}
Loading

0 comments on commit 2358c2b

Please sign in to comment.