Commit 2358c2b

Add basic FP8 KV cache support (#2603)
* Add basic FP8 KV cache support

  This change adds rudimentary FP8 KV cache support. It is enabled by passing `--kv-cache-dtype fp8_e5m2` to the launcher, which stores the KV cache in that dtype. Support is still limited:

  * Only the `fp8_e5m2` type is supported.
  * The KV cache layout is the same as for `float16`/`bfloat16` (HND).
  * The FP8 KV cache is only supported for FlashInfer.
  * Loading of scales is not yet supported.

* Fix Cargo.toml
1 parent 6810307 commit 2358c2b

33 files changed, +1015 −192 lines changed

Cargo.lock

Lines changed: 7 additions & 7 deletions
Some generated files are not rendered by default.

docs/source/reference/launcher.md

Lines changed: 9 additions & 0 deletions
````diff
@@ -89,6 +89,15 @@ Options:
           [env: DTYPE=]
           [possible values: float16, bfloat16]
 
+```
+## KV_CACHE_DTYPE
+```shell
+      --kv-cache-dtype <KV_CACHE_DTYPE>
+          Specify the dtype for the key-value cache. When this option is not provided, the dtype of the model is used (typically `float16` or `bfloat16`). Currently the only supported value is `fp8_e5m2` on CUDA
+
+          [env: KV_CACHE_DTYPE=]
+          [possible values: fp8_e5m2]
+
 ```
 ## TRUST_REMOTE_CODE
 ```shell
````
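For illustration, a minimal launch command using this new option could look like the following sketch. The model id is a placeholder, and the `text-generation-launcher` binary name is an assumption rather than something shown in this diff:

```shell
# Sketch only: enable the FP8 (e5m2) KV cache at launch time.
# Per the docs above, the environment variable KV_CACHE_DTYPE=fp8_e5m2 is equivalent.
text-generation-launcher \
    --model-id meta-llama/Meta-Llama-3-8B-Instruct \
    --kv-cache-dtype fp8_e5m2
```

Per the commit message, this currently takes effect only on the FlashInfer attention path, and the cache keeps the HND layout used for `float16`/`bfloat16`.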

integration-tests/conftest.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -336,6 +336,7 @@ def local_launcher(
         use_flash_attention: bool = True,
         disable_grammar_support: bool = False,
         dtype: Optional[str] = None,
+        kv_cache_dtype: Optional[str] = None,
         revision: Optional[str] = None,
         max_input_length: Optional[int] = None,
         max_batch_prefill_tokens: Optional[int] = None,
@@ -375,6 +376,9 @@ def local_launcher(
         if dtype is not None:
             args.append("--dtype")
             args.append(dtype)
+        if kv_cache_dtype is not None:
+            args.append("--kv-cache-dtype")
+            args.append(kv_cache_dtype)
         if revision is not None:
             args.append("--revision")
             args.append(revision)
@@ -434,6 +438,7 @@ def docker_launcher(
         use_flash_attention: bool = True,
         disable_grammar_support: bool = False,
         dtype: Optional[str] = None,
+        kv_cache_dtype: Optional[str] = None,
         revision: Optional[str] = None,
         max_input_length: Optional[int] = None,
         max_batch_prefill_tokens: Optional[int] = None,
@@ -456,6 +461,9 @@ def docker_launcher(
         if dtype is not None:
             args.append("--dtype")
             args.append(dtype)
+        if kv_cache_dtype is not None:
+            args.append("--kv-cache-dtype")
+            args.append(kv_cache_dtype)
         if revision is not None:
             args.append("--revision")
             args.append(revision)
@@ -589,7 +597,6 @@ async def generate_load_inner(
         max_new_tokens: int,
         seed: Optional[int] = None,
     ) -> List[Response]:
-
         import numpy as np
 
         arange = np.arange(len(prompts))
```
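To show how a test might consume the new `kv_cache_dtype` parameter, here is a minimal sketch in the style of these integration tests. The fixture names, model id, and the handle/client API (`health`, `client`, `generate`, `response_snapshot`) are assumptions based on the surrounding test suite, not part of this diff:

```python
import pytest


@pytest.fixture(scope="module")
def flash_llama_fp8_kv_handle(launcher):
    # Hypothetical fixture: `launcher` is the conftest fixture extended above;
    # kv_cache_dtype="fp8_e5m2" is forwarded as `--kv-cache-dtype fp8_e5m2`.
    with launcher("meta-llama/Meta-Llama-3-8B", kv_cache_dtype="fp8_e5m2") as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_llama_fp8_kv(flash_llama_fp8_kv_handle):
    # Wait for the server to come up, then hand out the client.
    await flash_llama_fp8_kv_handle.health(300)
    return flash_llama_fp8_kv_handle.client


@pytest.mark.asyncio
async def test_flash_llama_fp8_kv_cache(flash_llama_fp8_kv, response_snapshot):
    # Greedy generation of 10 tokens, matching the first snapshot below.
    response = await flash_llama_fp8_kv.generate(
        "What is deep learning?", max_new_tokens=10, decoder_input_details=True
    )
    assert response.details.generated_tokens == 10
    assert response == response_snapshot
```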
Lines changed: 104 additions & 0 deletions

```json
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      {
        "id": 128000,
        "logprob": null,
        "text": "<|begin_of_text|>"
      },
      {
        "id": 3923,
        "logprob": -5.6328125,
        "text": "What"
      },
      {
        "id": 374,
        "logprob": -1.2265625,
        "text": " is"
      },
      {
        "id": 5655,
        "logprob": -9.1015625,
        "text": " deep"
      },
      {
        "id": 6975,
        "logprob": -1.8085938,
        "text": " learning"
      },
      {
        "id": 30,
        "logprob": -1.0439453,
        "text": "?"
      }
    ],
    "seed": null,
    "tokens": [
      {
        "id": 18682,
        "logprob": -2.1992188,
        "special": false,
        "text": " Deep"
      },
      {
        "id": 6975,
        "logprob": -0.079956055,
        "special": false,
        "text": " learning"
      },
      {
        "id": 374,
        "logprob": -0.2763672,
        "special": false,
        "text": " is"
      },
      {
        "id": 264,
        "logprob": -0.37548828,
        "special": false,
        "text": " a"
      },
      {
        "id": 27084,
        "logprob": -1.4628906,
        "special": false,
        "text": " subset"
      },
      {
        "id": 315,
        "logprob": -0.02885437,
        "special": false,
        "text": " of"
      },
      {
        "id": 5780,
        "logprob": -0.2565918,
        "special": false,
        "text": " machine"
      },
      {
        "id": 6975,
        "logprob": -0.0063438416,
        "special": false,
        "text": " learning"
      },
      {
        "id": 430,
        "logprob": -1.3056641,
        "special": false,
        "text": " that"
      },
      {
        "id": 374,
        "logprob": -1.6035156,
        "special": false,
        "text": " is"
      }
    ],
    "top_tokens": null
  },
  "generated_text": " Deep learning is a subset of machine learning that is"
}
```
Lines changed: 57 additions & 0 deletions

```json
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "eos_token",
    "generated_tokens": 3,
    "prefill": [
      {
        "id": 128000,
        "logprob": null,
        "text": "<|begin_of_text|>"
      },
      {
        "id": 374,
        "logprob": -22.96875,
        "text": " is"
      },
      {
        "id": 5655,
        "logprob": -10.71875,
        "text": " deep"
      },
      {
        "id": 6975,
        "logprob": -2.6992188,
        "text": " learning"
      },
      {
        "id": 30,
        "logprob": -4.8398438,
        "text": "?"
      }
    ],
    "seed": 0,
    "tokens": [
      {
        "id": 720,
        "logprob": -0.4411621,
        "special": false,
        "text": " \n"
      },
      {
        "id": 220,
        "logprob": -0.35864258,
        "special": false,
        "text": " "
      },
      {
        "id": 128001,
        "logprob": 0.0,
        "special": true,
        "text": "<|end_of_text|>"
      }
    ],
    "top_tokens": null
  },
  "generated_text": "What is deep learning? \n "
}
```
