
Commit ac19f97

Pull upstream changes, fix conflict, bump version to 0.0.4
2 parents 82fa31f + 3e7c410

35 files changed: 1,395 additions & 512 deletions

README.md

Lines changed: 21 additions & 15 deletions
@@ -53,7 +53,7 @@ a minute to compile.
 
 Chatbot example:
 
-python test_chatbot.py -d <path_to_model_files> -un "Jeff" -p prompt_chatbort.txt
+python example_chatbot.py -d <path_to_model_files> -un "Jeff" -p prompt_chatbort.txt
 
 ## Web UI
 
@@ -120,11 +120,11 @@ docker run --gpus all -p 5000:5000 -v <path_to_model_files>:/app/model/ --rm -it
 |----------|------|-------|-----------------|----------------------|-----------|------------|---------|---------|------|
 | Llama | 7B | 128 | no | 2,048 t | 5,194 MB | 13,918 t/s | 173 t/s | 140 t/s | 6.45 |
 | Llama | 13B | 128 | no | 2,048 t | 9,127 MB | 7,507 t/s | 102 t/s | 86 t/s | 5.60 |
-| Llama | 30B | 128 | no | 2,048 t | 20,795 MB | 2,959 t/s | 47 t/s | 40 t/s | 4.60 |
-| Llama | 30B | 128 | yes | 2,048 t | 20,795 MB | 2,784 t/s | 45 t/s | 37 t/s | 4.55 |
-| Llama | 30B | 32 | yes | 1,550 t <sup>1</sup> | 21,486 MB | 2,636 t/s | 41 t/s | 37 t/s | 4.52 |
+| Llama | 33B | 128 | no | 2,048 t | 20,795 MB | 2,959 t/s | 47 t/s | 40 t/s | 4.60 |
+| Llama | 33B | 128 | yes | 2,048 t | 20,795 MB | 2,784 t/s | 45 t/s | 37 t/s | 4.55 |
+| Llama | 33B | 32 | yes | 1,550 t <sup>1</sup> | 21,486 MB | 2,636 t/s | 41 t/s | 37 t/s | 4.52 |
 | Koala | 13B | 128 | yes | 2,048 t | 9,127 MB | 5,529 t/s | 93 t/s | 79 t/s | 6.73 |
-| WizardLM | 30B | - | no <sup>2</sup> | 2,048 t | 20,199 MB | 2,313 t/s | 47 t/s | 40 t/s | 5.75 |
+| WizardLM | 33B | - | no <sup>2</sup> | 2,048 t | 20,199 MB | 2,313 t/s | 47 t/s | 40 t/s | 5.75 |
 
 <sup>1</sup> Can not achieve full sequence length without OoM (yet)
 <sup>2</sup> Not quite sure if this is act-order or not. Weights have no group index, at least
@@ -156,16 +156,16 @@ following benchmarks are from a 4090 + 3090-Ti with `-gs 17.2,24`:
 
 ### Testing long sequences
 
-The following tests were all done on **30B/65B, 4bit 128g** with various settings, just to test the max sequence length
+The following tests were all done on **33B/65B, 4bit 128g** with various settings, just to test the max sequence length
 and get a sense of what can be achieved with different or multiple GPUs right now. Llama goes incoherent generating
 past 2048 tokens anyway, but with some fine-tuning, who knows? Note that these tests were run a while ago and the
 speeds are no longer current.
 
 | | Size | Seq. len. | VRAM | Long seq. | Ind. |
 |------------------------|------|-----------|----------------------|-----------|--------|
-| 4090/24GB | 30B | 2,516 t | 22,145 MB | 1140 t/s | 28 t/s |
-| 4090/24GB + 3070Ti/8GB | 30B | 3,932 t | 22,055 MB + 7,377 MB | 840 t/s | 22 t/s |
-| A6000/48GB (headless) | 30B | 9,032 t | 46,863 MB | 645 t/s | 12 t/s |
+| 4090/24GB | 33B | 2,516 t | 22,145 MB | 1140 t/s | 28 t/s |
+| 4090/24GB + 3070Ti/8GB | 33B | 3,932 t | 22,055 MB + 7,377 MB | 840 t/s | 22 t/s |
+| A6000/48GB (headless) | 33B | 9,032 t | 46,863 MB | 645 t/s | 12 t/s |
 | A100/80GB (headless) | 65B | 9,520 t | 79,009 MB | 650 t/s | 9 t/s |
 
 ## Todo
@@ -197,18 +197,24 @@ for individual tokens, but benchmarks updated anyway. Closing in on 10k tokens/s
 rewrite at some point to make the client-side code less seizure-inducing. It has multibot mode, chat rewind and editing
 features, sessions, and more. I'm going to build it out with support for instruct prompting and such, in time.
 
-**2024-06-04**: Refactored a whole bunch to move more of the work into the extension, setting up for more tuning
+**2023-06-04**: Refactored a whole bunch to move more of the work into the extension, setting up for more tuning
 options to come soon and eventually auto tuning. Also optimized a little, for about a 5% speedup.
 
-**2024-06-06**: Some minor optimizations. Also it should now compile the extension more easily and run more seamlessly
+**2023-06-06**: Some minor optimizations. Also it should now compile the extension more easily and run more seamlessly
 on Windows.
 
-**2024-06-09**: Fused most of the self-attention step. More to come. Slight speedup already, but more importantly went
+**2023-06-09**: Fused most of the self-attention step. More to come. Slight speedup already, but more importantly went
 from 69% actual CPU utilization to 37%. This should do a lot to address the bottleneck on CPUs with lower
 single-threaded performance.
 
-**2024-06-10**: Docker support now! And some minor optimizations. Cleaned up the project a bit.
+**2023-06-10**: Docker support now! And some minor optimizations. Cleaned up the project a bit.
 
-**2024-06-11**: Added some concurrency a couple of places. It's only beneficial on the 4090, on small models where the
+**2023-06-11**: Added some concurrency a couple of places. It's only beneficial on the 4090, on small models where the
 cores are somewhat underutilized and the L2 cache can keep up. For the 3090 it's detrimental to performance, so it's
-disabled by default. YMMV. Use `-cs` to try it out.
+disabled by default. YMMV. Use `-cs` to try it out.
+
+**2023-06-17**: Fixed a nasty bug in the fused attention that was causing slightly incorrect cache states on 13B and
+33B models. You definitely want to update.
+
+**2023-06-18**: LoRA support now. Still needs a lot of testing and some optimization, and currently you can't stack
+multiple LoRAs during the same inference. There's also no support in the web UI yet.

cuda_ext.py

Lines changed: 159 additions & 0 deletions
@@ -0,0 +1,159 @@
# from abc import ABC
import torch
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.utils.cpp_extension import load
import os
import sys
import platform

library_dir = os.path.dirname(os.path.abspath(__file__))
extension_name = "exllama_ext"
verbose = False

# another kludge to get things compiling in Windows
windows = os.name == "nt"
if windows:

    def find_msvc():

        # Check the standard Visual Studio 2017-2022 install locations, newest version first
        for msvc_dir in [a + "\\Microsoft Visual Studio\\" + b + "\\" + c + "\\VC\\Tools\\MSVC\\"
                         for b in ["2022", "2019", "2017"]
                         for a in [os.environ["ProgramW6432"], os.environ["ProgramFiles(x86)"]]
                         for c in ["BuildTools", "Community", "Professional", "Enterprise", "Preview"]
                         ]:
            if not os.path.exists(msvc_dir):
                continue
            versions = sorted(os.listdir(msvc_dir), reverse = True)
            for version in versions:
                compiler_dir = msvc_dir + version + "\\bin\\Hostx64\\x64"
                if os.path.exists(compiler_dir) and os.path.exists(compiler_dir + "\\cl.exe"):
                    return compiler_dir
        return None

    # If cl.exe isn't already on the path, inject the MSVC compiler directory
    import subprocess
    try:
        subprocess.check_output(["where", "cl"])
    except subprocess.CalledProcessError as e:
        cl_path = find_msvc()
        if cl_path:
            print("Injected compiler path:", cl_path)
            os.environ["path"] += ";" + cl_path
        else:
            print("Unable to find cl.exe; compilation will probably fail.")

exllama_ext = load(
    name = extension_name,
    sources = [
        os.path.join(library_dir, "exllama_ext/exllama_ext.cpp"),
        os.path.join(library_dir, "exllama_ext/cuda_buffers.cu"),
        os.path.join(library_dir, "exllama_ext/cuda_func/q4_matrix.cu"),
        os.path.join(library_dir, "exllama_ext/cuda_func/q4_matmul.cu"),
        os.path.join(library_dir, "exllama_ext/cuda_func/column_remap.cu"),
        os.path.join(library_dir, "exllama_ext/cuda_func/rms_norm.cu"),
        os.path.join(library_dir, "exllama_ext/cuda_func/rope.cu"),
        os.path.join(library_dir, "exllama_ext/cuda_func/half_matmul.cu"),
        os.path.join(library_dir, "exllama_ext/cuda_func/q4_attn.cu"),
        os.path.join(library_dir, "exllama_ext/cuda_func/q4_mlp.cu"),
        os.path.join(library_dir, "exllama_ext/cpu_func/rep_penalty.cpp")
    ],
    extra_include_paths = [os.path.join(library_dir, "exllama_ext")],
    verbose = verbose,
    extra_ldflags = ["cublas.lib"] if windows else [],
    extra_cuda_cflags = ["-lineinfo"] + (["-U__HIP_NO_HALF_CONVERSIONS__", "-O3"] if torch.version.hip else []),
    extra_cflags = ["-O3"]
    # extra_cflags = ["-ftime-report", "-DTORCH_USE_CUDA_DSA"]
)

# from exllama_ext import set_tuning_params
# from exllama_ext import prepare_buffers
from exllama_ext import make_q4
from exllama_ext import q4_matmul
from exllama_ext import q4_matmul_lora
from exllama_ext import half_matmul
from exllama_ext import half_matmul_cublas
# from exllama_ext import q4_mlp
from exllama_ext import rms_norm
from exllama_ext import rope_
from exllama_ext import rep_penalty


# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension

none_tensor = torch.empty((1, 1), device = "meta")


# Construct Q4Matrix, return handle

def ext_make_q4(qweight, qzeros, scales, g_idx, device):

    return make_q4(qweight,
                   qzeros,
                   scales,
                   g_idx if g_idx is not None else none_tensor,
                   device)


# Matrix multiplication, returns x @ q4

def ext_q4_matmul(x, q4, q4_width, lora_A = None, lora_B = None):

    outshape = x.shape[:-1] + (q4_width,)
    x = x.view(-1, x.shape[-1])
    output = torch.empty((x.shape[0], q4_width), dtype = torch.float16, device = x.device)

    if lora_A is None:
        q4_matmul(x, q4, output)
    else:
        # Low-rank LoRA path: extra temp buffer holds x @ lora_A before projecting up with lora_B
        lora_temp = torch.empty((x.shape[0], lora_A.shape[1]), dtype = torch.float16, device = x.device)
        q4_matmul_lora(x, q4, output, lora_A, lora_B, lora_temp)

    return output.view(outshape)


# Matrix multiplication, returns x @ w, both half-precision tensors

def ext_half_matmul(x, w, cublas = False):

    outshape = x.shape[:-1] + (w.shape[1],)
    x = x.view(-1, x.shape[-1])

    if cublas:
        output = torch.empty((x.shape[0], w.shape[1]), dtype = torch.float16, device = x.device)
        half_matmul_cublas(x, w, output)
    else:
        output = torch.zeros((x.shape[0], w.shape[1]), dtype = torch.float16, device = x.device)
        half_matmul(x, w, output)

    return output.view(outshape)


# RoPE embeddings, in-place

def ext_rope_(x, sin, cos, past_len, num_heads, head_dim):

    rope_(x, sin, cos, past_len, num_heads, head_dim)


# RMS norm: x = x * w / sqrt(row_mean(x * x) + epsilon)

def ext_rms_norm(x, w, epsilon):

    outshape = x.shape
    x = x.view(-1, x.shape[-1])
    output = torch.empty_like(x)
    rms_norm(x, w, output, epsilon)

    return output.view(outshape)

# In-place variant of the above, overwrites x

def ext_rms_norm_(x, w, epsilon):

    outshape = x.shape
    x = x.view(-1, x.shape[-1])
    rms_norm(x, w, x, epsilon)


# Repetition penalty

def ext_rep_penalty_mask_cpu(vocab_size, sequence, penalty_max, sustain, decay):

    rep_mask = torch.empty(vocab_size, dtype = torch.float32)
    rep_penalty(sequence, rep_mask, penalty_max, sustain, decay)
    return rep_mask

datasets/download_datasets.py

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
# import torch
# from tokenizer import ExLlamaTokenizer
from datasets import load_dataset
import os

# Download samples from HF datasets to run a GPTQ-for-LLaMa-equivalent benchmark

def download_hf(filename, dataset, subset, split, key, div):

    print(f"Downloading from {dataset}: {subset}, split: {split} ...")
    hf_dataset = load_dataset(dataset, subset, split = split)
    data = div.join(hf_dataset[key])

    with open(filename, "w") as f:
        f.write(data)

download_hf("wikitext2.txt", "wikitext", "wikitext-2-raw-v1", "test", "text", "\n\n")
download_hf("ptb.txt", "ptb_text_only", "penn_treebank", "validation", "sentence", "\n\n")
download_hf("ptb_new.txt", "ptb_text_only", "penn_treebank", "test", "sentence", " ")
File renamed without changes.

doc/TODO.md

Lines changed: 11 additions & 6 deletions
@@ -3,8 +3,7 @@
 - [x] Support for act-order models ~~(a bit slow for now)~~
 - [x] ~~Support for v1 models without groupsize~~ Nah.
 - [x] Test more models
-- [ ] Consider support for loading GGML models
-- [ ] Utility to scan and validate .safetensors files
+- [x] Consider support for loading GGML models (not feasible)
 - [x] Figure out if there are quantized models with irregular groupsize (there are some at least with no groupsize)
 
 ## GPU compatibility (etc.)
@@ -22,8 +21,9 @@
 
 ## Testing
 
-- [ ] Figure out an apples-to-apples way of comparing perplexity with other implementations
+- [x] Figure out an apples-to-apples way of comparing perplexity with other implementations
 - [ ] Compile charts of inference speed vs context length for variety of models, compare to other implementations
+- [ ] Test a bunch of LoRAs to make sure all combinations of rank and target layers work
 
 ## VRAM optimization
 
@@ -41,27 +41,30 @@
 - [x] ~~Build attention mask in CUDA rather than PyTorch~~
 - [x] ~~Disable attention mask when it isn't needed~~ (not possible with SDP)
 - [x] Figure out why inference appears to be CPU-bound (kernel launch overhead)
-- [ ] Reduce no. kernel launches to minimum (tail launch, fusion etc.)
+- [x] Reduce no. kernel launches to minimum (tail launch, fusion etc.)
 - [x] Measure PyTorch module overhead (negligible in eval mode)
 - [x] Examine if scaled_dot_product_attention is actually the best attention method for single tokens (it's not)
 - [ ] Implement attention in CUDA
 - [x] Rewrite at least the quantized matmul kernel. Should be a bunch of special cases to consider
 - [x] Experiment with concurrent streams where possible (fused MLP and QKV proj.)
+- [x] Faster low-rank matmul to speed up LoRAs
 
 ## Generation
 
 - [x] Memory-efficient beam search implementation
 - [ ] Optimized beam search
 - [ ] Multi-token censoring/de-censoring
 - [ ] Multi-token repetition penalties
-- [ ] (Multi) LoRA support
+- [x] (Multi) LoRA support
+- [ ] Allow stackable LoRAs
 - [x] Guided generation (chat with multiple bots at once, etc.)
 - [ ] Multiple chat modes with prompt templates (instruct, etc.)
+- [ ] Batched generation
 
 ## Interface
 
 - [x] Simple web interface?
-- [ ] API server
+- [ ] API server
 
 ## Web UI
 
@@ -71,9 +74,11 @@
 - [ ] Make it a little prettier
 - [ ] Test various edge cases
 - [ ] Better error handling
+- [ ] LoRA controls
 
 ## ??
 
+- [ ] FP8/FP16 overlays
 - [ ] Allow for backpropagation
 - [ ] LoRA training features
 - [ ] Soft prompt training

doc/model_compatibility.md

Lines changed: 3 additions & 0 deletions
@@ -9,16 +9,19 @@ As of **2023-05-24**, the following GPTQ models on HuggingFace all appear to be
 - Neko-Institute-of-Science/LLaMA-65B-4bit-32g
 - Neko-Institute-of-Science/LLaMA-65B-4bit-128g
 - reeducator/bluemoonrp-13b
+- reeducator/bluemoonrp-30b
 - TehVenom/Metharme-13b-4bit-GPTQ
 - TheBloke/airoboros-13B-GPTQ
 - TheBloke/gpt4-x-vicuna-13B-GPTQ
 - TheBloke/GPT4All-13B-snoozy-GPTQ
 - TheBloke/guanaco-33B-GPTQ
+- TheBloke/guanaco-65B-GPTQ
 - TheBloke/h2ogpt-oasst1-512-30B-GPTQ <sup>1</sup>
 - TheBloke/koala-13B-GPTQ-4bit-128g
 - TheBloke/Manticore-13B-GPTQ
 - TheBloke/medalpaca-13B-GPTQ-4bit
 - TheBloke/medalpaca-13B-GPTQ-4bit (compat version)
+- TheBloke/Nous-Hermes-13B-GPTQ
 - TheBloke/tulu-30B-GPTQ
 - TheBloke/vicuna-13B-1.1-GPTQ-4bit-128g
 - TheBloke/VicUnlocked-30B-LoRA-GPTQ

example_basic.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
from model import ExLlama, ExLlamaCache, ExLlamaConfig
from tokenizer import ExLlamaTokenizer
from generator import ExLlamaGenerator
import os, glob

# Directory containing model, tokenizer, generator

model_directory = "/mnt/str/models/llama-13b-4bit-128g/"

# Locate files we need within that directory

tokenizer_path = os.path.join(model_directory, "tokenizer.model")
model_config_path = os.path.join(model_directory, "config.json")
st_pattern = os.path.join(model_directory, "*.safetensors")
model_path = glob.glob(st_pattern)[0]

# Create config, model, tokenizer and generator

config = ExLlamaConfig(model_config_path)               # create config from config.json
config.model_path = model_path                          # supply path to model weights file

model = ExLlama(config)                                 # create ExLlama instance and load the weights
tokenizer = ExLlamaTokenizer(tokenizer_path)            # create tokenizer from tokenizer model file

cache = ExLlamaCache(model)                             # create cache for inference
generator = ExLlamaGenerator(model, tokenizer, cache)   # create generator

# Configure generator

generator.disallow_tokens([tokenizer.eos_token_id])

generator.settings.token_repetition_penalty_max = 1.2
generator.settings.temperature = 0.95
generator.settings.top_p = 0.65
generator.settings.top_k = 100
generator.settings.typical = 0.5

# Produce a simple generation

prompt = "Once upon a time,"
print(prompt, end = "")

output = generator.generate_simple(prompt, max_new_tokens = 200)

print(output[len(prompt):])
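As a usage note, the same generator object can serve several prompts in a row. A small follow-on sketch, assuming (as the example implies) that generate_simple() starts a fresh generation on each call; the prompts and settings below are illustrative and not part of this commit:

# Follow-on sketch reusing the generator set up in example_basic.py
prompts = [
    "Once upon a time,",
    "The three most important things about quantized inference are",
]

for p in prompts:
    generator.settings.temperature = 0.7                  # settings can be adjusted between calls
    text = generator.generate_simple(p, max_new_tokens = 50)
    print("-----")
    print(text)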
File renamed without changes.
