diff --git a/.gitignore b/.gitignore
index 9dfb94c6..67adfd8e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -18,6 +18,8 @@ models/
*.bin
!llama_vocab.bin
!starcoder_vocab.bin
+!mistral_vocab.bin
+!llama3_vocab.bin
*.zip
*.txt
!requirements.txt
diff --git a/README.md b/README.md
index 0ef1d421..8fdbcd0f 100644
--- a/README.md
+++ b/README.md
@@ -1,20 +1,20 @@
![tinychat_logo](assets/figures/tinychat_logo.png)
-# TinyChatEngine: On-Device LLM Inference Library
+# TinyChatEngine: On-Device LLM/VLM Inference Library
-Running large language models (LLMs) on the edge is useful: copilot services (coding, office, smart reply) on laptops, cars, robots, and more. Users can get instant responses with better privacy, as the data is local.
+Running large language models (LLMs) and visual language models (VLMs) on the edge is useful: copilot services (coding, office, smart reply) on laptops, cars, robots, and more. Users can get instant responses with better privacy, as the data is local.
This is enabled by LLM model compression techniques: [SmoothQuant](https://github.com/mit-han-lab/smoothquant) and [AWQ (Activation-aware Weight Quantization)](https://github.com/mit-han-lab/llm-awq), co-designed with TinyChatEngine, which implements the compressed low-precision model.
Feel free to check out our [slides](assets/slides.pdf) for more details!
-### Code LLaMA Demo on an NVIDIA GeForce RTX 4070 laptop:
+### Code LLaMA Demo on NVIDIA GeForce RTX 4070 laptop:
![coding_demo_gpu](assets/figures/coding_demo_gpu.gif)
-### VILA Demo on an Apple MacBook Pro (M1, 2021):
+### VILA Demo on Apple MacBook M1 Pro:
![vlm_demo_m1](assets/figures/vlm_demo_m1.gif)
-### LLaMA Chat Demo on an Apple MacBook Pro (M1, 2021):
+### LLaMA Chat Demo on Apple MacBook M1 Pro:
![chat_demo_m1](assets/figures/chat_demo_m1.gif)
@@ -37,7 +37,10 @@ Feel free to check out our [slides](assets/slides.pdf) for more details!
## News
-- **(2024/02)** 🔥We extended the support for vision language models (VLM). Feel free to try running [VILA](#deploy-vision-language-model-vlm-chatbot-with-tinychatengine) on your edge device.
+- **(2024/05)** 🏆 AWQ and TinyChat received the **Best Paper Award** at **MLSys 2024**. 🎉
+- **(2024/05)** 🔥 We released the support for the **Llama-3** model family! Check out our example [here](#step-by-step-to-deploy-llama-3-8b-instruct-with-tinychatengine).
+- **(2024/02)** 🔥AWQ and TinyChat have been accepted to **MLSys 2024**!
+- **(2024/02)** 🔥We extended the support for **vision language models (VLM)**. Feel free to try running **[VILA](#deploy-vision-language-model-vlm-chatbot-with-tinychatengine)** on your edge device.
- **(2023/10)** We extended the support for the coding assistant [Code Llama](#download-and-deploy-models-from-our-model-zoo). Feel free to check it out.
- **(2023/10)** ⚡We released the new CUDA backend to support Nvidia GPUs with compute capability >= 6.1 for both server and edge GPUs. Its performance is also sped up by ~40% compared to the previous version. Feel free to check it out!
@@ -77,9 +80,9 @@ pacman -S --needed base-devel mingw-w64-x86_64-toolchain make unzip git
- Follow the instructions below and use x64 Native Tools Command Prompt from Visual Studio to compile TinyChatEngine.
-## Step-by-step to Deploy LLaMA2-7B-chat with TinyChatEngine
+## Step-by-step to Deploy Llama-3-8B-Instruct with TinyChatEngine
-Here, we provide step-by-step instructions to deploy LLaMA2-7B-chat with TinyChatEngine from scratch.
+Here, we provide step-by-step instructions to deploy Llama-3-8B-Instruct with TinyChatEngine from scratch.
- Download the repo.
```bash
@@ -94,17 +97,17 @@ Here, we provide step-by-step instructions to deploy LLaMA2-7B-chat with TinyCha
conda activate TinyChatEngine
pip install -r requirements.txt
```
-- Download the quantized LLaMA2-7B-chat model from our model zoo.
+- Download the quantized Llama model from our model zoo.
```bash
cd llm
```
- On an x86 device (e.g., Intel/AMD laptop)
```bash
- python tools/download_model.py --model LLaMA2_7B_chat_awq_int4 --QM QM_x86
+ python tools/download_model.py --model LLaMA_3_8B_Instruct_awq_int4 --QM QM_x86
```
- On an ARM device (e.g., M1/M2 Macbook, Raspberry Pi)
```bash
- python tools/download_model.py --model LLaMA2_7B_chat_awq_int4 --QM QM_ARM
+ python tools/download_model.py --model LLaMA_3_8B_Instruct_awq_int4 --QM QM_ARM
```
- On a CUDA device (e.g., Jetson AGX Orin, PC/Server)
```bash
@@ -118,30 +121,28 @@ Here, we provide step-by-step instructions to deploy LLaMA2-7B-chat with TinyCha
./chat
TinyChatEngine by MIT HAN Lab: https://github.com/mit-han-lab/TinyChatEngine
- Using model: LLaMA2_7B_chat
+ Using model: LLaMA_3_8B_Instruct
Using AWQ for 4bit quantization: https://github.com/mit-han-lab/llm-awq
Loading model... Finished!
- USER: Write a syllabus for Operating Systems.
- ASSISTANT:
- Of course! Here is a sample syllabus for a college-level course on operating systems:
- Course Title: Introduction to Operating Systems
- Course Description: This course provides an overview of the fundamental concepts and techniques used in modern operating systems, including process management, memory management, file systems, security, and I/O devices. Students will learn how these components work together to provide a platform for running applications and programs on a computer.
- Course Objectives:
- * Understand the basic architecture of an operating system
- * Learn about processes, threads, and process scheduling algorithms
- * Study memory management techniques such as paging and segmentation
- * Explore file systems including file organization, storage devices, and file access methods
- * Investigate security mechanisms to protect against malicious software attacks
- * Analyze input/output (I/O) operations and their handling by the operating system
+ USER: Write a syllabus for the parallel computing course.
+ ASSISTANT: Here is a sample syllabus for a parallel computing course:
+
+ **Course Title:** Parallel Computing
+ **Instructor:** [Name]
+ **Description:** This course covers the fundamental concepts of parallel computing, including parallel algorithms, programming models, and architectures. Students will learn how to design, implement, and optimize parallel programs using various languages and frameworks.
+ **Prerequisites:** Basic knowledge of computer science and programming concepts.
+ **Course Objectives:**
+ * Understand the principles of parallelism and its applications
+ * Learn how to write parallel programs using different languages (e.g., OpenMP, MPI)
...
```
TinyChatEngine supports not only LLMs but also VLMs. We introduce a sophisticated chatbot for VLMs. Here, we provide easy-to-follow instructions to deploy a vision language model chatbot (VILA-7B) with TinyChatEngine. We recommend using M1/M2 MacBooks for this VLM feature.
-- Follow the instructions above to setup the basic environment, i.e., [Prerequisites](#prerequisites) and [Step-by-step to Deploy LLaMA2-7B-chat with TinyChatEngine](#step-by-step-to-deploy-llama2-7b-chat-with-tinychatengine).
+- Follow the instructions above to set up the basic environment, i.e., [Prerequisites](#prerequisites) and [Step-by-step to Deploy Llama-3-8B-Instruct with TinyChatEngine](#step-by-step-to-deploy-llama-3-8b-instruct-with-tinychatengine).
- To demonstrate images in the terminal, please download and install the following toolkit.
- Install [termvisage](https://github.com/AnonymouX47/termvisage).
@@ -204,11 +205,11 @@ TinyChatEngine supports not only LLM but also VLM. We introduce a sophisticated
## Backend Support
-| Precision | x86 <br /> (Intel/AMD CPU) | ARM <br /> (Apple M1/M2 & RPi) | Nvidia GPU | Apple GPU |
-| ------ | --------------------------- | --------- | --------- | --------- |
+| Precision | x86 <br /> (Intel/AMD CPU) | ARM <br /> (Apple M1/M2 & RPi) | Nvidia GPU |
+| ------ | --------------------------- | --------- | --------- |
| FP32 | ✅ | ✅ | |
-| W4A16 | | | ✅ | ✅
-| W4A32 | ✅ | ✅ | | ✅
+| W4A16 | | | ✅ |
+| W4A32 | ✅ | ✅ | |
| W4A8 | ✅ | ✅ | |
| W8A8 | ✅ | ✅ | |
@@ -247,6 +248,22 @@ We offer a selection of models that have been tested with TinyChatEngine. These
+    <tr>
+        <td rowspan="2">LLaMA_3_8B_Instruct</td>
+        <td>fp32</td>
+        <td>LLaMA_3_8B_Instruct_fp32</td>
+        <td>✅</td>
+        <td>✅</td>
+        <td> </td>
+    </tr>
+    <tr>
+        <!-- first column merged with the cell above -->
+        <td>int4</td>
+        <td>LLaMA_3_8B_Instruct_awq_int4</td>
+        <td>✅</td>
+        <td>✅</td>
+        <td> </td>
+    </tr>
    <tr>
        <td rowspan="2">LLaMA2_13B_chat</td>
        <td>fp32</td>
@@ -327,6 +344,22 @@ We offer a selection of models that have been tested with TinyChatEngine. These
        <td>✅</td>
        <td>✅</td>
    </tr>
+    <tr>
+        <td rowspan="2">Mistral-7B-Instruct-v0.2</td>
+        <td>fp32</td>
+        <td>Mistral_7B_v0.2_Instruct_fp32</td>
+        <td>✅</td>
+        <td>✅</td>
+        <td> </td>
+    </tr>
+    <tr>
+        <!-- first column merged with the cell above -->
+        <td>int4</td>
+        <td>Mistral_7B_v0.2_Instruct_awq_int4</td>
+        <td>✅</td>
+        <td>✅</td>
+        <td> </td>
+    </tr>
    <tr>
        <td rowspan="2">VILA-7B</td>
        <td>fp32</td>
diff --git a/assets/figures/vlm_demo/CPR.jpg b/assets/figures/vlm_demo/CPR.jpg
new file mode 100644
index 00000000..d7793d8b
Binary files /dev/null and b/assets/figures/vlm_demo/CPR.jpg differ
diff --git a/assets/figures/vlm_demo/Wall_fissure.png b/assets/figures/vlm_demo/Wall_fissure.png
new file mode 100644
index 00000000..b50b9468
Binary files /dev/null and b/assets/figures/vlm_demo/Wall_fissure.png differ
diff --git a/assets/figures/vlm_demo/car.png b/assets/figures/vlm_demo/car.png
new file mode 100644
index 00000000..5c54fd5d
Binary files /dev/null and b/assets/figures/vlm_demo/car.png differ
diff --git a/assets/figures/vlm_demo/pedestrian.png b/assets/figures/vlm_demo/pedestrian.png
index f42fd7d0..a34505d5 100755
Binary files a/assets/figures/vlm_demo/pedestrian.png and b/assets/figures/vlm_demo/pedestrian.png differ
diff --git a/assets/figures/vlm_demo/statue.jpg b/assets/figures/vlm_demo/statue.jpg
new file mode 100644
index 00000000..18513b1c
Binary files /dev/null and b/assets/figures/vlm_demo/statue.jpg differ
diff --git a/kernels/matmul.h b/kernels/matmul.h
index 8c186ad4..0424edee 100644
--- a/kernels/matmul.h
+++ b/kernels/matmul.h
@@ -99,8 +99,12 @@ struct thread_args {
int start_i, end_i, blk_size;
};
+#ifndef MAX
#define MAX(A, B) ((A) > (B) ? (A) : (B))
+#endif
+#ifndef MIN
#define MIN(A, B) ((A) < (B) ? (A) : (B))
+#endif
namespace matmul {
class MatmulOperator {
diff --git a/llm/application/chat.cc b/llm/application/chat.cc
index c98ee7ff..f2c0087a 100644
--- a/llm/application/chat.cc
+++ b/llm/application/chat.cc
@@ -12,8 +12,8 @@ std::map<std::string, int> model_config = {
{"CodeLLaMA_7B_Instruct", CodeLLaMA_7B}, {"CodeLLaMA_13B_Instruct", CodeLLaMA_13B},
{"StarCoder", StarCoder_15_5B}, {"StarCoder_15.5B", StarCoder_15_5B}, {"LLaVA_7B", LLaVA_7B}, {"LLaVA_13B", LLaVA_13B},
{"VILA_2.7B", VILA_2_7B}, {"VILA_7B", VILA_7B}, {"VILA_13B", VILA_13B}, {"Clip_ViT_Large", Clip_ViT_Large},
- {"Mistral_7B", Mistral_7B}
- };
+ {"Mistral_7B", Mistral_7B}, {"LLaMA_3_8B_Instruct", LLaMA_3_8B}, {"VILA1.5_8B", VILA1_5_8B},
+};
std::map model_path = {{"OPT_125m", "models/OPT_125m"},
{"OPT_1.3B", "models/OPT_1.3B"},
@@ -34,12 +34,22 @@ std::map<std::string, std::string> model_path = {{"OPT_125m", "models/OPT_125m"}
{"VILA_13B", "models/VILA_13B"},
{"Clip_ViT_Large", "models/CLIP_ViT_Large"},
{"Mistral_7B", "models/Mistral_7B"},
+ {"LLaMA_3_8B_Instruct", "models/LLaMA_3_8B_Instruct"},
+ {"VILA1.5_8B", "models/VILA1.5_8B"},
};
std::map<std::string, int> data_format_list = {
{"FP32", FP32}, {"INT8", QINT8}, {"INT4", INT4}, {"int4", INT4}, {"fp32", FP32},
};
+bool isLLaMA3(std::string s) {
+ std::string LLaMA_prefix = "LLaMA_3";
+ if (s.substr(0, LLaMA_prefix.size()) == LLaMA_prefix)
+ return true;
+ else
+ return false;
+}
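+// Note: main() checks isLLaMA3() before isLLaMA(), since "LLaMA_3_8B_Instruct" also matches the generic "LLaMA" prefix.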
+
bool isLLaMA(std::string s) {
std::string LLaMA_prefix = "LLaMA";
std::string CodeLLaMA_prefix = "CodeLLaMA";
@@ -73,6 +83,14 @@ bool isLLaVA(std::string s) {
return false;
}
+bool isVILA1_5(std::string s) {
+ std::string VILA_prefix = "VILA1.5";
+ if (s.substr(0, VILA_prefix.size()) == VILA_prefix)
+ return true;
+ else
+ return false;
+}
+
bool isVILA(std::string s) {
std::string VILA_prefix = "VILA";
if (s.substr(0, VILA_prefix.size()) == VILA_prefix)
@@ -120,7 +138,7 @@ int main(int argc, char* argv[]) {
}
}
- std::string target_model = "LLaMA2_7B_chat";
+ std::string target_model = "LLaMA_3_8B_Instruct";
std::string target_data_format = "INT4";
bool instruct = true;
std::string img_path = "images/monalisa.jpg";
@@ -138,7 +156,7 @@ int main(int argc, char* argv[]) {
NUM_THREAD = atoi(argv[3]);
}
if (argc == 5) {
- if (isCodeLLaMA(target_model)) {
+ if (isCodeLLaMA(target_model) || isMistral(target_model)) {
instruct = convertToBool(argv[4]);
}
else if (isLLaVA(target_model) || isVILA(target_model)) {
@@ -203,7 +221,134 @@ int main(int argc, char* argv[]) {
}
}
- if (isLLaMA(target_model)) {
+ if (isLLaMA3(target_model)) {
+ int format_id = data_format_list[target_data_format];
+
+ // Voicechat instructions
+ if (use_voicechat) {
+ std::cout << "You are using the TinyVoiceChat." << std::endl;
+ std::cout << "*Usage instructions*" << std::endl;
+ std::cout << "- Please use this mode in a quiet environment to have a better user experience and avoid speech misdetection." << std::endl;
+ std::cout << "- Please start speaking after \"USER: [Start speaking]\" shows up." << std::endl;
+ std::cout << "- Please press `Ctrl+C` multiple times to exit the program." << std::endl << std::endl;
+ }
+
+ // Load model
+ std::cout << "Loading model... " << std::flush;
+ int model_id = model_config[target_model];
+ std::string m_path = model_path[target_model];
+
+ #ifdef MODEL_PREFIX
+ m_path = MODEL_PREFIX + m_path;
+ #endif
+
+ struct opt_params generation_config;
+ generation_config.n_predict = 2048;
+ generation_config.repeat_penalty = 1.1f;
+ generation_config.temp = 0.7f;
+ generation_config.n_vocab = 128256;
+ generation_config.top_p = 0.9f;
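+ // Note: Llama-3 uses a 128256-token vocabulary (vs. 32000 for Llama-2), so n_vocab must match models/llama3_vocab.bin.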
+
+ bool first_prompt = true;
+
+ if (format_id == FP32) {
+ Fp32LlamaForCausalLM model = Fp32LlamaForCausalLM(m_path, get_opt_model_config(model_id));
+ std::cout << "Finished!" << std::endl << std::endl;
+
+ // Get input from the user
+ while (true) {
+ std::string input;
+ if (use_voicechat) {
+ // Set prompt color
+ set_print_yellow();
+ int result = std::system("./application/sts_utils/listen");
+ std::ifstream in("tmpfile");
+ // set user input color
+ set_print_red();
+ std::getline(in, input);
+ result = std::system("rm tmpfile");
+ (void)result;
+ std::cout << input << std::endl;
+ // reset color
+ set_print_reset();
+ } else {
+ // Set prompt color
+ set_print_yellow();
+ std::cout << "USER: ";
+ // set user input color
+ set_print_red();
+ std::getline(std::cin, input);
+ // reset color
+ set_print_reset();
+ }
+ if (input == "quit" || input == "Quit" || input == "Quit." || input == "quit.")
+ break;
+ if (instruct) {
+ std::cout << "ASSISTANT: ";
+ }
+
+ if (first_prompt) {
+ input = "A chat between a curious human (\"Human\") and an artificial intelligence assistant (\"Assistant\"). The assistant gives detailed, helpful, and polite answers to the human's questions.\n\nHuman: " + input + "\nAssistant: ";
+ first_prompt = false;
+ }
+ else {
+ input = "Human: " + input + "\nAssistant: \n";
+ }
+
+ LLaMA3Generate(m_path, &model, LLaMA_FP32, input, generation_config, "models/llama3_vocab.bin", true, false);
+ }
+ } else if (format_id == INT4) {
+ m_path = "INT4/" + m_path;
+ Int4LlamaForCausalLM model = Int4LlamaForCausalLM(m_path, get_opt_model_config(model_id));
+ std::cout << "Finished!" << std::endl << std::endl;
+
+ // Get input from the user
+ while (true) {
+ std::string input;
+ if (use_voicechat) {
+ // Set prompt color
+ set_print_yellow();
+ int result = std::system("./application/sts_utils/listen");
+ std::ifstream in("tmpfile");
+ // set user input color
+ set_print_red();
+ std::getline(in, input);
+ result = std::system("rm tmpfile");
+ (void)result;
+ std::cout << input << std::endl;
+ // reset color
+ set_print_reset();
+ } else {
+ // Set prompt color
+ set_print_yellow();
+ std::cout << "USER: ";
+ // set user input color
+ set_print_red();
+ std::getline(std::cin, input);
+ // reset color
+ set_print_reset();
+ }
+ if (input == "quit" || input == "Quit" || input == "Quit." || input == "quit.")
+ break;
+ if (instruct) {
+ std::cout << "ASSISTANT: ";
+ }
+
+ if (first_prompt) {
+ input = "A chat between a curious human (\"Human\") and an artificial intelligence assistant (\"Assistant\"). The assistant gives detailed, helpful, and polite answers to the human's questions.\n\nHuman: " + input + "\nAssistant: ";
+ first_prompt = false;
+ }
+ else {
+ input = "Human: " + input + "\nAssistant: \n";
+ }
+
+ LLaMA3Generate(m_path, &model, LLaMA_INT4, input, generation_config, "models/llama3_vocab.bin", true, use_voicechat);
+ }
+ } else {
+ std::cout << std::endl;
+ std::cerr << "At this time, we only support FP32 and INT4 for LLaMA_3_8B_Instruct." << std::endl;
+ }
+ } else if (isLLaMA(target_model)) {
int format_id = data_format_list[target_data_format];
// Voicechat instructions
@@ -289,7 +434,7 @@ int main(int argc, char* argv[]) {
if (!isCodeLLaMA(target_model)) {
if (first_prompt) {
- input = "A chat between a human and an assistant.\n\n### Human: " + input + "\n### Assistant: \n";
+ input = "A chat between a curious human (\"Human\") and an artificial intelligence assistant (\"Assistant\"). The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n### Human: " + input + "\n### Assistant: ";
first_prompt = false;
}
else {
@@ -352,7 +497,7 @@ int main(int argc, char* argv[]) {
if (!isCodeLLaMA(target_model)) {
if (first_prompt) {
- input = "A chat between a human and an assistant.\n\n### Human: " + input + "\n### Assistant: \n";
+ input = "A chat between a curious human (\"Human\") and an artificial intelligence assistant (\"Assistant\"). The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n### Human: " + input + "\n### Assistant: ";
first_prompt = false;
}
else {
@@ -578,6 +723,158 @@ int main(int argc, char* argv[]) {
std::cout << std::endl;
std::cerr << "At this time, we only support FP32 and INT4 for LLaVA_7B." << std::endl;
}
+ } else if (isVILA1_5(target_model)) {
+ int format_id = data_format_list[target_data_format];
+
+ // Voicechat instructions
+ if (use_voicechat) {
+ std::cout << "You are using the TinyVoiceChat." << std::endl;
+ std::cout << "*Usage instructions*" << std::endl;
+ std::cout << "- Please use this mode in a quiet environment to have a better user experience and avoid speech misdetection." << std::endl;
+ std::cout << "- Please start speaking after \"USER: [Start speaking]\" shows up." << std::endl;
+ std::cout << "- Please press `Ctrl+C` multiple times to exit the program." << std::endl << std::endl;
+ }
+
+ // Load model
+ std::cout << "Loading model... " << std::flush;
+ std::string clip_m_path = model_path["Clip_ViT_Large"];
+ std::string llama_m_path = model_path[target_model];
+
+ int clip_model_id = model_config["Clip_ViT_Large"];
+ int llama_model_id = model_config[target_model];
+
+ #ifdef MODEL_PREFIX
+ llama_m_path = MODEL_PREFIX + llama_m_path;
+ #endif
+
+ struct opt_params generation_config;
+ generation_config.n_predict = 512;
+ generation_config.repeat_penalty = 1.1f;
+ generation_config.temp = 0.2f;
+ generation_config.n_vocab = 32000;
+ generation_config.top_p = 1.0f;
+
+ int prompt_iter = 0;
+
+ if (format_id == FP32) {
+ Fp32CLIPVisionTransformer clip_model = Fp32CLIPVisionTransformer(clip_m_path, get_opt_model_config(clip_model_id), true);
+ Fp32LlamaForCausalLM llama_model = Fp32LlamaForCausalLM(llama_m_path, get_opt_model_config(llama_model_id));
+
+ // Get input from the user
+ while (true) {
+ std::string input;
+ if (prompt_iter == 1) {
+ // Set prompt color
+ set_print_yellow();
+ std::cout << "Finished!" << std::endl << std::endl;
+ // reset color
+ set_print_reset();
+ }
+ if (prompt_iter > 0) {
+ if (use_voicechat) {
+ // Set prompt color
+ set_print_yellow();
+ int result = std::system("./application/sts_utils/listen");
+ std::ifstream in("tmpfile");
+ // set user input color
+ set_print_red();
+ std::getline(in, input);
+ result = std::system("rm tmpfile");
+ (void)result;
+ std::cout << input << std::endl;
+ // reset color
+ set_print_reset();
+ } else {
+ // Set prompt color
+ set_print_yellow();
+ std::cout << "USER: ";
+ // set user input color
+ set_print_red();
+ std::getline(std::cin, input);
+ // reset color
+ set_print_reset();
+ }
+ if (input == "quit" || input == "Quit" || input == "Quit." || input == "quit.")
+ break;
+ std::cout << "ASSISTANT: ";
+ }
+
+ if (prompt_iter == 0) {
+ input = "This is a chat between a user and an assistant.\n\n### USER: ";
+ prompt_iter += 1;
+ }
+ else if (prompt_iter == 1) {
+ input = "\n" + input + "\n### ASSISTANT:";
+ prompt_iter += 1;
+ }
+ else {
+ input = "### USER: " + input + "\n### ASSISTANT: \n";
+ }
+
+ LLaVAGenerate(llama_m_path, &llama_model, clip_m_path, &clip_model, VILA_FP32, input, img_path, generation_config, "models/llama_vocab.bin", true, false, true);
+ }
+ } else if (format_id == INT4) {
+ Fp32CLIPVisionTransformer clip_model = Fp32CLIPVisionTransformer(clip_m_path, get_opt_model_config(clip_model_id), true);
+ llama_m_path = "INT4/" + llama_m_path;
+ Int4LlamaForCausalLM llama_model = Int4LlamaForCausalLM(llama_m_path, get_opt_model_config(llama_model_id));
+
+ // Get input from the user
+ while (true) {
+ if (prompt_iter == 1) {
+ // Set prompt color
+ set_print_yellow();
+ std::cout << "Finished!" << std::endl << std::endl;
+ // reset color
+ set_print_reset();
+ }
+ std::string input;
+ if (prompt_iter > 0) {
+ if (use_voicechat) {
+ // Set prompt color
+ set_print_yellow();
+ int result = std::system("./application/sts_utils/listen");
+ std::ifstream in("tmpfile");
+ // set user input color
+ set_print_red();
+ std::getline(in, input);
+ result = std::system("rm tmpfile");
+ (void)result;
+ std::cout << input << std::endl;
+ // reset color
+ set_print_reset();
+ } else {
+ // Set prompt color
+ set_print_yellow();
+ std::cout << "USER: ";
+ // set user input color
+ set_print_red();
+ std::getline(std::cin, input);
+ // reset color
+ set_print_reset();
+ }
+ if (input == "quit" || input == "Quit" || input == "Quit." || input == "quit.")
+ break;
+ std::cout << "ASSISTANT: ";
+ }
+
+ if (prompt_iter == 0) {
+ input = "This is a chat between a user and an assistant.\n\n### USER: ";
+ prompt_iter += 1;
+ }
+ else if (prompt_iter == 1) {
+ input = "\n" + input + "\n### ASSISTANT:";
+ prompt_iter += 1;
+ }
+ else {
+ input = "### USER: " + input + "\n### ASSISTANT: \n";
+ }
+
+ LLaVAGenerate(llama_m_path, &llama_model, clip_m_path, &clip_model, VILA_INT4, input, img_path, generation_config, "models/llama_vocab.bin", true, use_voicechat, true);
+ }
+ } else {
+ std::cout << std::endl;
+ std::cerr << "At this time, we only support FP32 and INT4 for VILA1.5_8B." << std::endl;
+ }
} else if (isVILA(target_model)) {
int format_id = data_format_list[target_data_format];
@@ -791,19 +1088,27 @@ int main(int argc, char* argv[]) {
}
if (input == "quit" || input == "Quit" || input == "Quit." || input == "quit.")
break;
- if (instruct) {
- std::cout << "ASSISTANT: ";
- }
- if (first_prompt) {
- input = "A chat between a curious human (\"Human\") and an artificial intelligence assistant (\"Assistant\"). The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n### Human: " + input + "\n### Assistant: ";
- first_prompt = false;
- }
- else {
- input = "### Human: " + input + "\n### Assistant: \n";
+ std::cout << "ASSISTANT: ";
+ if (instruct) {
+ if (first_prompt) {
+ input = "[INST] " + input + " [/INST] ";
+ first_prompt = false;
+ }
+ else {
+ input = " [INST] " + input + " [/INST] ";
+ }
+ } else {
+ if (first_prompt) {
+ input = "A chat between a curious human (\"Human\") and an artificial intelligence assistant (\"Assistant\"). The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n### Human: " + input + "\n### Assistant: ";
+ first_prompt = false;
+ }
+ else {
+ input = "### Human: " + input + "\n### Assistant: \n";
+ }
}
- LLaMAGenerate(m_path, &model, LLaMA_FP32, input, generation_config, "models/llama_vocab.bin", true, false);
+ MistralGenerate(m_path, &model, LLaMA_FP32, input, generation_config, "models/mistral_vocab.bin", true, false);
}
} else if (format_id == INT4) {
m_path = "INT4/" + m_path;
@@ -838,19 +1143,27 @@ int main(int argc, char* argv[]) {
}
if (input == "quit" || input == "Quit" || input == "Quit." || input == "quit.")
break;
- if (instruct) {
- std::cout << "ASSISTANT: ";
- }
- if (first_prompt) {
- input = "A chat between a curious human (\"Human\") and an artificial intelligence assistant (\"Assistant\"). The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n### Human: " + input + "\n### Assistant: ";
- first_prompt = false;
- }
- else {
- input = "### Human: " + input + "\n### Assistant: \n";
+ std::cout << "ASSISTANT: ";
+ if (instruct) {
+ if (first_prompt) {
+ input = "[INST] " + input + " [/INST] ";
+ first_prompt = false;
+ }
+ else {
+ input = " [INST] " + input + " [/INST] ";
+ }
+ } else {
+ if (first_prompt) {
+ input = "A chat between a curious human (\"Human\") and an artificial intelligence assistant (\"Assistant\"). The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n### Human: " + input + "\n### Assistant: ";
+ first_prompt = false;
+ }
+ else {
+ input = "### Human: " + input + "\n### Assistant: \n";
+ }
}
- LLaMAGenerate(m_path, &model, LLaMA_INT4, input, generation_config, "models/llama_vocab.bin", true, use_voicechat);
+ MistralGenerate(m_path, &model, LLaMA_INT4, input, generation_config, "models/mistral_vocab.bin", true, use_voicechat);
}
} else {
std::cout << std::endl;
diff --git a/llm/chat-13b b/llm/chat_llama2-13b
similarity index 100%
rename from llm/chat-13b
rename to llm/chat_llama2-13b
diff --git a/llm/chat_llama2-7b b/llm/chat_llama2-7b
new file mode 100755
index 00000000..e9b7b900
--- /dev/null
+++ b/llm/chat_llama2-7b
@@ -0,0 +1,2 @@
+#!/bin/bash
+./chat LLaMA2_7B_chat INT4 5
diff --git a/llm/include/Generate.h b/llm/include/Generate.h
index bfd54c09..63e1aa4f 100644
--- a/llm/include/Generate.h
+++ b/llm/include/Generate.h
@@ -112,4 +112,10 @@ std::string LLaVAGenerate(std::string llama_param_path, void* llama_model_ptr, s
std::string text, std::string img_path, const struct opt_params generation_config, std::string voc_path, bool interactive,
bool voicechat, bool is_vila);
+std::string MistralGenerate(std::string param_path, void* model, int model_type, std::string text, const struct opt_params generation_config,
+ std::string voc_path, bool interactive, bool voicechat);
+
+std::string LLaMA3Generate(std::string param_path, void* model, int model_type, std::string text, const struct opt_params generation_config,
+ std::string voc_path, bool interactive, bool voicechat);
+
#endif // GENERATE_H
diff --git a/llm/include/model.h b/llm/include/model.h
index aa08dff6..7554aded 100644
--- a/llm/include/model.h
+++ b/llm/include/model.h
@@ -5,6 +5,7 @@
struct model_config {
int batch;
int num_heads;
+ int num_kv_heads;  // number of key/value heads for GQA/MQA models (set by the GQA constructor below)
int num_layers;
int max_sqlen;
int embed_dim;
@@ -30,6 +31,19 @@ struct model_config {
vocsize(vocsize),
padding_idx(padding_idx),
rms_norm_eps(rms_norm_eps) {}
+ // GQA/MQA models
+ model_config(int batch, int num_heads, int num_kv_heads, int num_layers, int max_sqlen, int embed_dim, int hidden_dim, int vocsize,
+ int padding_idx, float rms_norm_eps)
+ : batch(batch),
+ num_heads(num_heads),
+ num_kv_heads(num_kv_heads),
+ num_layers(num_layers),
+ max_sqlen(max_sqlen),
+ embed_dim(embed_dim),
+ hidden_dim(hidden_dim),
+ vocsize(vocsize),
+ padding_idx(padding_idx),
+ rms_norm_eps(rms_norm_eps) {}
// Clip models
model_config(int batch, int num_heads, int num_layers, int max_sqlen, int embed_dim, int hidden_dim, int vocsize,
int padding_idx, float rms_norm_eps, int image_size, int patch_size, int projection_dim, int mmproj_dim)
@@ -48,24 +62,25 @@ struct model_config {
mmproj_dim(mmproj_dim) {}
};
-enum { OPT_125M, OPT_1_3B, OPT_6_7B, LLaMA_7B, LLaMA_13B, CodeLLaMA_7B, CodeLLaMA_13B, StarCoder_15_5B, LLaVA_7B, LLaVA_13B, VILA_2_7B, VILA_7B, VILA_13B, Clip_ViT_Large, Mistral_7B};
+enum { OPT_125M, OPT_1_3B, OPT_6_7B, LLaMA_7B, LLaMA_13B, CodeLLaMA_7B, CodeLLaMA_13B, StarCoder_15_5B, LLaVA_7B, LLaVA_13B, VILA_2_7B, VILA_7B, VILA_13B, Clip_ViT_Large, Mistral_7B, LLaMA_3_8B, VILA1_5_8B };
enum { FP32, QINT8, INT4 };
const struct model_config opt_6_7B(1, 32, 32, 2048, 4096, 16384, 50272, 1, 0);
const struct model_config opt_1_3B(1, 32, 24, 2048, 2048, 8192, 50272, 1, 0);
const struct model_config opt_125m(1, 12, 12, 2048, 768, 3072, 50272, 1, 0);
-const struct model_config llama_7B(1, 32, 32, 2048, 4096, 11008, 32000, 1, 1e-6);
-const struct model_config llama_13B(1, 40, 40, 2048, 5120, 13824, 32000, 1, 1e-6);
-const struct model_config codellama_7B(1, 32, 32, 2048, 4096, 11008, 32016, 1, 1e-5);
-const struct model_config codellama_13B(1, 40, 40, 2048, 5120, 13824, 32016, 1, 1e-5);
+const struct model_config llama_7B(1, 32, 32, 32, 2048, 4096, 11008, 32000, 1, 1e-6);
+const struct model_config llama_13B(1, 40, 40, 40, 2048, 5120, 13824, 32000, 1, 1e-6);
+const struct model_config codellama_7B(1, 32, 32, 32, 2048, 4096, 11008, 32016, 1, 1e-5);
+const struct model_config codellama_13B(1, 40, 40, 40, 2048, 5120, 13824, 32016, 1, 1e-5);
const struct model_config starcoder_15_5B(1, 48, 40, 2048, 6144, 24576, 49152, 1, 0);
-const struct model_config llava_7B(1, 32, 32, 2048, 4096, 11008, 32000, 1, 1e-5);
-const struct model_config llava_13B(1, 40, 40, 2048, 5120, 13824, 32000, 1, 1e-5);
-const struct model_config vila_2_7B(1, 20, 32, 2048, 2560, 6912, 32000, 1, 1e-5);
-const struct model_config vila_7B(1, 32, 32, 2048, 4096, 11008, 32000, 1, 1e-5);
-const struct model_config vila_13B(1, 40, 40, 2048, 5120, 13824, 32000, 1, 1e-5);
+const struct model_config llava_7B(1, 32, 32, 32, 2048, 4096, 11008, 32000, 1, 1e-5);
+const struct model_config llava_13B(1, 40, 40, 40, 2048, 5120, 13824, 32000, 1, 1e-5);
+const struct model_config vila_2_7B(1, 20, 20, 32, 2048, 2560, 6912, 32000, 1, 1e-5);
+const struct model_config vila_7B(1, 32, 32, 32, 2048, 4096, 11008, 32000, 1, 1e-5);
+const struct model_config vila_13B(1, 40, 40, 40, 2048, 5120, 13824, 32000, 1, 1e-5);
const struct model_config clip_vit_large(1, 16, 23, 2048, 1024, 4096, 0, 1, 0, 336, 14, 768, 4096); // llava's and vila's clip model uses only 23 layers out of 24
-const struct model_config mistral_7B(1, 32, 32, 2048, 4096, 11008, 32000, 1, 1e-6);
+const struct model_config mistral_7B(1, 32, 8, 32, 2048, 4096, 14336, 32000, 1, 1e-5);
+const struct model_config llama_3_8B(1, 32, 8, 32, 2048, 4096, 14336, 128256, 1, 1e-5);
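+// GQA configs: (batch, num_heads, num_kv_heads, num_layers, max_sqlen, embed_dim, hidden_dim, vocsize, padding_idx, rms_norm_eps); both models share 32 query heads across 8 KV heads.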
static struct model_config get_opt_model_config(int choise) {
struct model_config ret;
@@ -115,6 +130,12 @@ static struct model_config get_opt_model_config(int choise) {
case Mistral_7B:
ret = mistral_7B;
break;
+ case LLaMA_3_8B:
+ ret = llama_3_8B;
+ break;
+ case VILA1_5_8B:
+ ret = vila_7B;
+ break;
default:
throw("Unsupported model choice.");
break;
diff --git a/llm/include/nn_modules/Int4llamaAttention.h b/llm/include/nn_modules/Int4llamaAttention.h
index 316cd975..41ea7806 100644
--- a/llm/include/nn_modules/Int4llamaAttention.h
+++ b/llm/include/nn_modules/Int4llamaAttention.h
@@ -71,7 +71,7 @@ class Int4llamaAttention {
private:
std::string profile_name = "Int4llamaAttention";
- int embed_dim, num_heads, head_dim;
+ int embed_dim, num_heads, num_kv_heads, head_dim;
#ifdef QM_CUDA
Linear_half_int4 o_proj, qkv_proj;
RotaryPosEmb_cuda rotary_pos_emb;
@@ -81,9 +81,11 @@ class Int4llamaAttention {
Linear_FP_int4 k_proj, v_proj, q_proj, o_proj, qkv_proj;
RotaryPosEmb rotary_pos_emb;
BMM_F32T qk_bmm, pv_bmm;
- void unshape(Matrix3D<float> shaped, Matrix3D<float> unshape, int sqlen);
- void shape(Matrix3D<float> unshape, Matrix3D<float> shaped, int sqlen);
+ void unshape(Matrix3D<float> shaped, Matrix3D<float> unshape, int num_heads, int head_dim, int sqlen);
+ void shape(Matrix3D<float> unshape, Matrix3D<float> shaped, int num_heads, int head_dim, int sqlen);
void shape_qkv(Matrix3D<float> unshape, Matrix3D<float> shaped_q, Matrix3D<float> shaped_k,
Matrix3D<float> shaped_v, int sqlen);
+ void repeat(Matrix3D<float> input, Matrix3D<float> output, int num_heads, int num_kv_heads, int sqlen, int head_dim);
+
#endif
};
diff --git a/llm/mistral b/llm/mistral
index 0efb0a2a..9d13c143 100755
--- a/llm/mistral
+++ b/llm/mistral
@@ -1,2 +1,2 @@
# !/bin/bash
-./chat Mistral_7B INT4 5
+./chat Mistral_7B INT4 5 0
diff --git a/llm/models/llama3_vocab.bin b/llm/models/llama3_vocab.bin
new file mode 100644
index 00000000..eccc1874
Binary files /dev/null and b/llm/models/llama3_vocab.bin differ
diff --git a/llm/models/mistral_vocab.bin b/llm/models/mistral_vocab.bin
new file mode 100644
index 00000000..ba60e8ff
Binary files /dev/null and b/llm/models/mistral_vocab.bin differ
diff --git a/llm/src/nn_modules/MistralGenerate.cc b/llm/src/nn_modules/MistralGenerate.cc
new file mode 100644
index 00000000..df775734
--- /dev/null
+++ b/llm/src/nn_modules/MistralGenerate.cc
@@ -0,0 +1,243 @@
+#include <cstring>
+#include <iostream>
+#include <thread>
+
+#include "Generate.h"
+#include "LLaMATokenizer.h"
+#include "common.h"
+#include "utils.h"
+#include "interface.h"
+
+// Function to speak in the background
+static void sayInBackground(const std::string& text) {
+ std::string command = "./application/sts_utils/speak \"" + text + "\"";
+ int result = std::system(command.c_str());
+ (void)result;
+}
+
+std::string MistralGenerate(std::string param_path, void *model_ptr, int model_type, std::string text, const struct opt_params generation_config,
+ std::string voc_path, bool interactive, bool voicechat) {
+ std::vector<int> last_n_tokens(generation_config.n_ctx);
+ std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
+ std::vector<int> embd;
+ std::vector<int> generate_ids;
+
+ const int max_token = 2048;
+ std::vector<int> input_ids(max_token);
+ llama_vocab vocab = llama_init_vocab(voc_path.c_str());
+ const int n = llama_tokenize(vocab, text.c_str(), input_ids.data(), input_ids.size(), true);
+ input_ids.resize(n);
+
+ int n_consumed = 0;
+ while ((int)input_ids.size() > n_consumed) {
+ embd.push_back(input_ids[n_consumed]);
+ last_n_tokens.erase(last_n_tokens.begin());
+ last_n_tokens.push_back(input_ids[n_consumed]);
+ ++n_consumed;
+
+ if ((int)embd.size() >= generation_config.n_batch) {
+ break;
+ }
+ }
+ // if (interactive) std::cout << "ASSISTANT: " << std::endl;
+
+ int break_cnt = 2;
+ bool new_prompt = true;
+ static bool has_past_kv = false;
+ static std::vector<Matrix3D<float>> past_keys, past_values;
+ int n_remain = generation_config.n_predict;
+ std::string output;
+ while (n_remain != 0 && break_cnt) {
+ std::vector<float> logits(generation_config.n_vocab);
+
+ int sqlen = 1;
+ if (new_prompt) {
+ sqlen = input_ids.size();
+ }
+ if (model_type == LLaMA_INT4) {
+ Int4LlamaForCausalLM *model = static_cast<Int4LlamaForCausalLM *>(model_ptr);
+ struct Int4LlamaForCausalLM_output model_output;
+ struct Int4LlamaForCausalLM_input model_input;
+ if (has_past_kv) {
+ Matrix3D<int> input_ids_mat(input_ids.data(), 1, 1, sqlen);
+ model_input = {input_ids_mat, past_keys, past_values};
+ } else {
+ Matrix3D<int> input_ids_mat(input_ids.data(), 1, 1, sqlen);
+ model_input = {input_ids_mat};
+ }
+ if (!new_prompt) STATS_START("Inference latency");
+ model_output = model->forward(param_path, model_input);
+ if (!new_prompt) STATS_END("Inference latency");
+ past_keys = model_output.past_keys;
+ past_values = model_output.past_values;
+ // memcpy model_output.logits[-1] to logits
+ memcpy(logits.data(), &model_output.logits.m_data[(sqlen - 1) * generation_config.n_vocab],
+ generation_config.n_vocab * sizeof(float));
+ } else if (model_type == LLaMA_FP32) {
+ Fp32LlamaForCausalLM *model = static_cast<Fp32LlamaForCausalLM *>(model_ptr);
+ struct Fp32LlamaForCausalLM_output model_output;
+ struct Fp32LlamaForCausalLM_input model_input;
+ if (has_past_kv) {
+ Matrix3D<int> input_ids_mat(input_ids.data(), 1, 1, sqlen);
+ model_input = {input_ids_mat, past_keys, past_values};
+ } else {
+ Matrix3D<int> input_ids_mat(input_ids.data(), 1, 1, sqlen);
+ model_input = {input_ids_mat};
+ }
+ if (!new_prompt) STATS_START("Inference latency");
+ model_output = model->forward(model_input);
+ if (!new_prompt) STATS_END("Inference latency");
+ past_keys = model_output.past_keys;
+ past_values = model_output.past_values;
+ // memcpy model_output.logits[-1] to logits
+ memcpy(logits.data(), &model_output.logits.m_data[(sqlen - 1) * generation_config.n_vocab],
+ generation_config.n_vocab * sizeof(float));
+ }
+ has_past_kv = true;
+
+ // Generate
+ const int n_ctx = generation_config.n_ctx;
+ const float temp = generation_config.temp;
+ const int32_t top_k = generation_config.top_k <= 0 ? generation_config.n_vocab : generation_config.top_k;
+ const float top_p = generation_config.top_p;
+ const float tfs_z = generation_config.tfs_z;
+ const float typical_p = generation_config.typical_p;
+ const int32_t repeat_last_n = generation_config.repeat_last_n < 0 ? n_ctx : generation_config.repeat_last_n;
+ const float repeat_penalty = generation_config.repeat_penalty;
+ const float alpha_presence = generation_config.presence_penalty;
+ const float alpha_frequency = generation_config.frequency_penalty;
+ const int mirostat = generation_config.mirostat;
+ const float mirostat_tau = generation_config.mirostat_tau;
+ const float mirostat_eta = generation_config.mirostat_eta;
+ const int n_vocab = generation_config.n_vocab;
+
+ std::vector<OPT_token_data> candidates;
+ candidates.reserve(n_vocab);
+ for (int token_id = 0; token_id < n_vocab; token_id++) {
+ candidates.emplace_back(OPT_token_data{token_id, logits[token_id], 0.0f});
+ }
+
+ OPT_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
+
+ // Apply penalties
+ auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
+ sample_repetition_penalty(&candidates_p, last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+ last_n_repeat, repeat_penalty);
+ sample_frequency_and_presence_penalties(&candidates_p,
+ last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+ last_n_repeat, alpha_frequency, alpha_presence);
+
+ int id = 0;
+ if (temp <= 0) {
+ id = sample_token_greedy(&candidates_p);
+ } else {
+ if (mirostat == 1) {
+ static float mirostat_mu = 2.0f * mirostat_tau;
+ const int mirostat_m = 100;
+ sample_temperature(&candidates_p, temp);
+ id =
+ sample_token_mirostat(n_vocab, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
+ } else if (mirostat == 2) {
+ static float mirostat_mu = 2.0f * mirostat_tau;
+ sample_temperature(&candidates_p, temp);
+ id = sample_token_mirostat_v2(&candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
+ } else {
+ // Temperature sampling
+ sample_top_k(&candidates_p, top_k, 1);
+ sample_tail_free(&candidates_p, tfs_z, 1);
+ sample_typical(&candidates_p, typical_p, 1);
+ sample_top_p(&candidates_p, top_p, 1);
+ sample_temperature(&candidates_p, temp);
+ id = sample_token(&candidates_p);
+ }
+ }
+
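+ // SentencePiece special ids for Mistral: 1 = <s> (bos), 2 = </s> (eos); 27332 decodes to "###" and is treated as an end-of-turn marker below.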
+ if (id == 2) {
+ break_cnt--;
+ continue;
+ } // eos
+ else if (id == 1)
+ continue;
+ break_cnt = 2;
+
+ bool skip = false;
+ if (id == 27332) { // token = ###
+ break_cnt = 0;
+ skip = true;
+ }
+
+
+ last_n_tokens.erase(last_n_tokens.begin());
+ last_n_tokens.push_back(id);
+ embd.push_back(id);
+ generate_ids.push_back(id);
+ input_ids = std::vector{id};
+
+ if (interactive && !skip) {
+ output += llama_id_to_token(vocab, id);
+ std::cout << llama_id_to_token(vocab, id) << std::flush;
+ if (voicechat) {
+ // Remove quotes
+ output.erase(std::remove(output.begin(), output.end(), '\"'), output.end());
+ // Remove hashtags
+ output.erase(std::remove(output.begin(), output.end(), '#'), output.end());
+ // Remove dashes
+ std::replace(output.begin(), output.end(), '-', ' ');
+
+ size_t lastPos;
+ // starts earlier but slows down dictation
+ bool ended = false;
+ if (output.find(", ") != std::string::npos){
+ lastPos = output.rfind(',');
+ ended = true;
+ }
+ if (output.find("\n") != std::string::npos){
+ lastPos = output.rfind('\n');
+ ended = true;
+ }
+ else if (output.find(". ") != std::string::npos){
+ lastPos = output.rfind('.');
+ ended = true;
+ }
+ else if (output.find("! ") != std::string::npos){
+ lastPos = output.rfind('!');
+ ended = true;
+ }
+ else if (output.find("? ") != std::string::npos){
+ lastPos = output.rfind('?');
+ ended = true;
+
+ }
+ else if (output.find(": ") != std::string::npos){
+ lastPos = output.rfind(':');
+ ended = true;
+ }
+ if (ended){
+ // Extract sentence 1 (up to and including the last period)
+ std::string output_copy = output.substr(0, lastPos + 1);
+ // Extract beginning of sentence 2 (excluding the space after the last period)
+ output = output.substr(lastPos + 1); // Skip the last period and space
+ std::thread sayThread(sayInBackground, output_copy);
+ sayThread.detach();
+ }
+ }
+ }
+
+ new_prompt = false;
+ --n_remain;
+ }
+ if (voicechat && interactive){
+ sayInBackground(output);
+ }
+
+ if (interactive) std::cout << std::endl;
+
+ // Set prompt color
+ set_print_yellow();
+ Profiler::getInstance().report_internal();
+ Profiler::getInstance().reset();
+ // Reset color
+ set_print_reset();
+
+ return output;
+}
diff --git a/llm/src/nn_modules/non_cuda/Int4llamaAttention.cc b/llm/src/nn_modules/non_cuda/Int4llamaAttention.cc
index 0505af12..b8fb9d42 100644
--- a/llm/src/nn_modules/non_cuda/Int4llamaAttention.cc
+++ b/llm/src/nn_modules/non_cuda/Int4llamaAttention.cc
@@ -22,6 +22,8 @@ static float *query_states_unshape_arr;
static float *key_states_unshape_arr;
static float *value_states_unshape_arr;
// static float *qkv_states_unshape_arr;
+static float *final_key_states_arr;
+static float *final_value_states_arr;
#if DEC_SHARED_MEM
static uint8_t *q_weight, *k_weight, *v_weight, *o_weight;
@@ -60,6 +62,9 @@ void Int4llamaAttention::initialized_memory(const struct model_config config) {
allocate_aligned_memory(query_states_unshape_arr, config.max_sqlen * config.embed_dim * sizeof(float));
allocate_aligned_memory(key_states_unshape_arr, config.max_sqlen * config.embed_dim * sizeof(float));
allocate_aligned_memory(value_states_unshape_arr, config.max_sqlen * config.embed_dim * sizeof(float));
+
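+ // Scratch buffers for K/V states after the num_kv_heads GQA heads are broadcast back to num_heads (see repeat() in forward()).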
+ allocate_aligned_memory(final_key_states_arr, config.max_sqlen * config.embed_dim * sizeof(float));
+ allocate_aligned_memory(final_value_states_arr, config.max_sqlen * config.embed_dim * sizeof(float));
// allocate_aligned_memory(qkv_states_unshape_arr, config.max_sqlen * config.embed_dim * 3 * sizeof(float));
#if DEC_SHARED_MEM
@@ -87,19 +92,19 @@ void Int4llamaAttention::initialized_memory(const struct model_config config) {
#endif
}
-inline void Int4llamaAttention::shape(Matrix3D<float> unshape, Matrix3D<float> shaped, int sqlen) {
+inline void Int4llamaAttention::shape(Matrix3D<float> unshape, Matrix3D<float> shaped, int num_heads, int head_dim, int sqlen) {
PROFILE_START("Int4llamaAttention::shape");
assert(unshape.m_dim_x == 1); // bsz == 1
assert(unshape.m_dim_y == sqlen);
- assert(unshape.m_dim_z == this->num_heads * this->head_dim);
- assert(shaped.m_dim_x == this->num_heads);
+ assert(unshape.m_dim_z == num_heads * head_dim);
+ assert(shaped.m_dim_x == num_heads);
assert(shaped.m_dim_y == sqlen);
- assert(shaped.m_dim_z == this->head_dim);
+ assert(shaped.m_dim_z == head_dim);
- for (int i = 0; i < this->num_heads; i++) {
+ for (int i = 0; i < num_heads; i++) {
for (int j = 0; j < sqlen; j++) {
- for (int k = 0; k < this->head_dim; k++) {
- shaped(i, j, k) = unshape(0, j, i * this->head_dim + k);
+ for (int k = 0; k < head_dim; k++) {
+ shaped(i, j, k) = unshape(0, j, i * head_dim + k);
}
}
}
@@ -139,39 +144,69 @@ inline void Int4llamaAttention::shape(Matrix3D<float> unshape, Matrix3D<float> s
// PROFILE_END("Int4llamaAttention::shape_qkv");
// }
-inline void Int4llamaAttention::unshape(Matrix3D<float> shaped, Matrix3D<float> unshape, int sqlen) {
+inline void Int4llamaAttention::unshape(Matrix3D<float> shaped, Matrix3D<float> unshape, int num_heads, int head_dim, int sqlen) {
PROFILE_START("Int4llamaAttention::unshape");
assert(unshape.m_dim_x == 1); // bsz == 1
assert(unshape.m_dim_y == sqlen);
- assert(unshape.m_dim_z == this->num_heads * this->head_dim);
- assert(shaped.m_dim_x == this->num_heads);
+ assert(unshape.m_dim_z == num_heads * head_dim);
+ assert(shaped.m_dim_x == num_heads);
assert(shaped.m_dim_y == sqlen);
- assert(shaped.m_dim_z == this->head_dim);
+ assert(shaped.m_dim_z == head_dim);
- for (int i = 0; i < this->num_heads; i++) {
+ for (int i = 0; i < num_heads; i++) {
for (int j = 0; j < sqlen; j++) {
- for (int k = 0; k < this->head_dim; k++) {
- unshape(0, j, i * this->head_dim + k) = shaped(i, j, k);
+ for (int k = 0; k < head_dim; k++) {
+ unshape(0, j, i * head_dim + k) = shaped(i, j, k);
}
}
}
PROFILE_END("Int4llamaAttention::unshape");
}
+inline void Int4llamaAttention::repeat(Matrix3D<float> input, Matrix3D<float> output, int num_heads, int num_kv_heads, int sqlen, int head_dim) {
+ PROFILE_START("Int4llamaAttention::repeat");
+ int n_repeat = num_heads / num_kv_heads;
+ assert(input.m_dim_x == num_kv_heads);
+ assert(input.m_dim_y == sqlen);
+ assert(input.m_dim_z == head_dim);
+ assert(output.m_dim_x == num_heads);
+ assert(output.m_dim_y == sqlen);
+ assert(output.m_dim_z == head_dim);
+
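+ // Each of the num_kv_heads K/V heads is shared by n_repeat query heads: output head i copies input head i / n_repeat.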
+ for (int i = 0; i < num_heads; i++) {
+ for (int j = 0; j < sqlen; j++) {
+ for (int k = 0; k < head_dim; k++) {
+ output(i, j, k) = input(i / n_repeat, j, k);
+ }
+ }
+ }
+ PROFILE_END("Int4llamaAttention::repeat");
+}
+
Int4llamaAttention::Int4llamaAttention(std::string param_path, const struct model_config config, int layer_idx) {
+ this->embed_dim = config.embed_dim;
+ this->num_heads = config.num_heads;
+ this->num_kv_heads = config.num_kv_heads;
+ assert(config.embed_dim % config.num_heads == 0);
+ this->head_dim = config.embed_dim / config.num_heads;
+
#if !(DEC_SHARED_MEM)
uint8_t *q_weight, *k_weight, *v_weight, *o_weight, *qkv_weight;
allocate_aligned_memory(q_weight, (config.embed_dim * config.embed_dim * sizeof(uint8_t)) / 2);
- allocate_aligned_memory(k_weight, (config.embed_dim * config.embed_dim * sizeof(uint8_t)) / 2);
- allocate_aligned_memory(v_weight, (config.embed_dim * config.embed_dim * sizeof(uint8_t)) / 2);
+ // allocate_aligned_memory(k_weight, (config.embed_dim * config.embed_dim * sizeof(uint8_t)) / 2);
+ // allocate_aligned_memory(v_weight, (config.embed_dim * config.embed_dim * sizeof(uint8_t)) / 2);
+ allocate_aligned_memory(k_weight, (config.embed_dim * config.num_kv_heads * this->head_dim * sizeof(uint8_t)) / 2);
+ allocate_aligned_memory(v_weight, (config.embed_dim * config.num_kv_heads * this->head_dim * sizeof(uint8_t)) / 2);
allocate_aligned_memory(o_weight, (config.embed_dim * config.embed_dim * sizeof(uint8_t)) / 2);
// allocate_aligned_memory(qkv_weight, (config.embed_dim * config.embed_dim * 3 * sizeof(uint8_t)) / 2);
this->q_proj =
Linear_FP_int4(Matrix3D<uint8_t>(q_weight, 1, config.embed_dim, config.embed_dim / 2), param_path + "/q_proj");
this->k_proj =
- Linear_FP_int4(Matrix3D<uint8_t>(k_weight, 1, config.embed_dim, config.embed_dim / 2), param_path + "/k_proj");
+ // Linear_FP_int4(Matrix3D<uint8_t>(k_weight, 1, config.embed_dim, config.embed_dim / 2), param_path + "/k_proj");
+ Linear_FP_int4(Matrix3D<uint8_t>(k_weight, 1, config.num_kv_heads * this->head_dim, config.embed_dim / 2), param_path + "/k_proj");
this->v_proj =
- Linear_FP_int4(Matrix3D<uint8_t>(v_weight, 1, config.embed_dim, config.embed_dim / 2), param_path + "/v_proj");
+ // Linear_FP_int4(Matrix3D<uint8_t>(v_weight, 1, config.embed_dim, config.embed_dim / 2), param_path + "/v_proj");
+ Linear_FP_int4(Matrix3D<uint8_t>(v_weight, 1, config.num_kv_heads * this->head_dim, config.embed_dim / 2), param_path + "/v_proj");
this->o_proj =
Linear_FP_int4(Matrix3D<uint8_t>(o_weight, 1, config.embed_dim, config.embed_dim / 2), param_path + "/o_proj");
// this->qkv_proj =
@@ -190,11 +225,6 @@ Int4llamaAttention::Int4llamaAttention(std::string param_path, const struct mode
read_to_array((param_path + "/qk_bmm/alpha.bin").c_str(), &qk_bmm_alpha, 1);
this->qk_bmm = BMM_F32T(qk_bmm_alpha);
this->pv_bmm = BMM_F32T(1.0f);
-
- this->embed_dim = config.embed_dim;
- this->num_heads = config.num_heads;
- assert(config.embed_dim % config.num_heads == 0);
- this->head_dim = config.embed_dim / config.num_heads;
}
struct transpose_1_2idx_float_arg {
@@ -292,7 +322,7 @@ struct Int4llamaAttention_output Int4llamaAttention::forward(std::string param_p
this->q_proj.forward(input.hidden_states, query_states_unshape);
PROFILE_END(profile_name + "::q_proj");
Matrix3D<float> query_states(query_states_arr, this->num_heads, sqlen, this->head_dim);
- this->shape(query_states_unshape, query_states, sqlen);
+ this->shape(query_states_unshape, query_states, this->num_heads, this->head_dim, sqlen);
// Get the memory buffer
float *ret_value_states, *ret_key_states;
@@ -307,20 +337,20 @@ struct Int4llamaAttention_output Int4llamaAttention::forward(std::string param_p
}
// Key states
- Matrix3D<float> key_states_unshape(key_states_unshape_arr, b, sqlen, embed_dim);
+ Matrix3D<float> key_states_unshape(key_states_unshape_arr, b, sqlen, num_kv_heads * head_dim);
PROFILE_START(profile_name + "::k_proj");
this->k_proj.forward(input.hidden_states, key_states_unshape);
PROFILE_END(profile_name + "::k_proj");
- Matrix3D<float> key_states(key_states_arr, this->num_heads, sqlen, this->head_dim);
- this->shape(key_states_unshape, key_states, sqlen);
+ Matrix3D<float> key_states(key_states_arr, this->num_kv_heads, sqlen, this->head_dim);
+ this->shape(key_states_unshape, key_states, this->num_kv_heads, this->head_dim, sqlen);
// Value states
- Matrix3D<float> value_states_unshape(value_states_unshape_arr, b, sqlen, embed_dim);
+ Matrix3D<float> value_states_unshape(value_states_unshape_arr, b, sqlen, num_kv_heads * head_dim);
PROFILE_START(profile_name + "::v_proj");
this->v_proj.forward(input.hidden_states, value_states_unshape);
PROFILE_END(profile_name + "::v_proj");
- Matrix3D<float> value_states(value_states_arr, this->num_heads, sqlen, this->head_dim);
- this->shape(value_states_unshape, value_states, sqlen);
+ Matrix3D<float> value_states(value_states_arr, this->num_kv_heads, sqlen, this->head_dim);
+ this->shape(value_states_unshape, value_states, this->num_kv_heads, this->head_dim, sqlen);
// Rotate position
int start_idx = 0;
@@ -349,13 +379,19 @@ struct Int4llamaAttention_output Int4llamaAttention::forward(std::string param_p
}
} else {
// Put the data into the buffer
- memcpy(ret_value_states, value_states_arr, (this->num_heads * tgz * this->head_dim) * sizeof(float));
- memcpy(ret_key_states, key_states_arr, (this->num_heads * tgz * this->head_dim) * sizeof(float));
+ memcpy(ret_value_states, value_states_arr, (this->num_kv_heads * tgz * this->head_dim) * sizeof(float));
+ memcpy(ret_key_states, key_states_arr, (this->num_kv_heads * tgz * this->head_dim) * sizeof(float));
}
- Matrix3D<float> final_value_states(ret_value_states, this->num_heads, tgz, this->head_dim);
- Matrix3D<float> final_key_states(ret_key_states, this->num_heads, tgz, this->head_dim);
+ Matrix3D<float> final_GQA_value_states(ret_value_states, this->num_kv_heads, tgz, this->head_dim);
+ Matrix3D<float> final_GQA_key_states(ret_key_states, this->num_kv_heads, tgz, this->head_dim);
PROFILE_END(profile_name + "::cat_past_keys_values");
+ // Repeat KV
+ Matrix3D<float> final_value_states(final_value_states_arr, this->num_heads, tgz, this->head_dim);
+ Matrix3D<float> final_key_states(final_key_states_arr, this->num_heads, tgz, this->head_dim);
+ this->repeat(final_GQA_value_states, final_value_states, this->num_heads, this->num_kv_heads, tgz, this->head_dim);
+ this->repeat(final_GQA_key_states, final_key_states, this->num_heads, this->num_kv_heads, tgz, this->head_dim);
+
// QK_BMM
Matrix3D<float> attn_weights(attn_weights_arr, this->num_heads, sqlen, tgz);
PROFILE_START(profile_name + "::qk_bmm");
@@ -389,7 +425,7 @@ struct Int4llamaAttention_output Int4llamaAttention::forward(std::string param_p
PROFILE_END(profile_name + "::pv_bmm");
Matrix3D<float> attn_output_transpose(attn_output_transpose_arr, 1, sqlen, this->num_heads * this->head_dim);
- this->unshape(attn_output, attn_output_transpose, sqlen);
+ this->unshape(attn_output, attn_output_transpose, this->num_heads, this->head_dim, sqlen);
// Output projection
Matrix3D<float> attn_output_fp(attn_output_fp_arr, 1, sqlen, this->num_heads * this->head_dim);
@@ -399,7 +435,7 @@ struct Int4llamaAttention_output Int4llamaAttention::forward(std::string param_p
// Output assignment
output.attn_output = attn_output_fp;
- output.past_key_value = {final_key_states, final_value_states};
+ output.past_key_value = {final_GQA_key_states, final_GQA_value_states};
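+ // The KV cache stores the compact num_kv_heads tensors; they are re-expanded to num_heads by repeat() on the next forward pass.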
PROFILE_END(profile_name);
return output;
diff --git a/llm/src/nn_modules/non_cuda/LLaMA3Generate.cc b/llm/src/nn_modules/non_cuda/LLaMA3Generate.cc
new file mode 100644
index 00000000..43803a60
--- /dev/null
+++ b/llm/src/nn_modules/non_cuda/LLaMA3Generate.cc
@@ -0,0 +1,477 @@
+#include <cstring>
+#include <iostream>
+#include <string>
+#include <thread>
+
+#include "Generate.h"
+#include "LLaMATokenizer.h"
+#include "common.h"
+#include "utils.h"
+#include "interface.h"
+
+// Function to speak in the background
+static void sayInBackground(const std::string& text) {
+ std::string command = "./application/sts_utils/speak \"" + text + "\"";
+ int result = std::system(command.c_str());
+ (void)result;
+}
+
+typedef struct {
+ char *str;
+ int id;
+} TokenIndex;
+
+typedef struct {
+ char** vocab;
+ float* vocab_scores;
+ TokenIndex *sorted_vocab;
+ int vocab_size;
+ unsigned int max_token_length;
+ unsigned char byte_pieces[512]; // stores all single-byte strings
+} Tokenizer;
+
+int compare_tokens(const void *a, const void *b);
+void build_tokenizer(Tokenizer* t, const char* tokenizer_path, int vocab_size);
+void free_tokenizer(Tokenizer* t);
+char* decode(Tokenizer* t, int token);
+int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size);
+void encode(Tokenizer* t, const char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens);
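+// Minimal tokenizer interface for the Llama-3 BPE vocab: build_tokenizer loads vocab strings, scores, and byte pieces from the .bin file; encode/decode convert between UTF-8 text and token ids.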
+
+std::string LLaMA3Generate(std::string param_path, void *model_ptr, int model_type, std::string text, const struct opt_params generation_config,
+ std::string voc_path, bool interactive, bool voicechat) {
+ std::vector<int> last_n_tokens(generation_config.n_ctx);
+ std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
+ std::vector<int> embd;
+ std::vector<int> generate_ids;
+
+ // llama_vocab vocab = llama_init_vocab(voc_path.c_str());
+ // build the Tokenizer via the tokenizer .bin file
+ Tokenizer tokenizer;
+ build_tokenizer(&tokenizer, voc_path.c_str(), generation_config.n_vocab);
+
+ const int max_token = 2048;
+ std::vector<int> input_ids(max_token);
+ // const int n = llama_tokenize(vocab, text.c_str(), input_ids.data(), input_ids.size(), true);
+ int num_tokens = 0;
+ // encode the user prompt into tokens
+ encode(&tokenizer, text.c_str(), 0, 0, input_ids.data(), &num_tokens);
+ input_ids.resize(num_tokens);
+
+ int n_consumed = 0;
+ while ((int)input_ids.size() > n_consumed) {
+ embd.push_back(input_ids[n_consumed]);
+ last_n_tokens.erase(last_n_tokens.begin());
+ last_n_tokens.push_back(input_ids[n_consumed]);
+ ++n_consumed;
+
+ if ((int)embd.size() >= generation_config.n_batch) {
+ break;
+ }
+ }
+
+ int break_cnt = 2;
+ bool new_prompt = true;
+ static bool has_past_kv = false;
+ static std::vector<Matrix3D<float>> past_keys, past_values;
+ int n_remain = generation_config.n_predict;
+ std::string output;
+ while (n_remain != 0 && break_cnt) {
+ std::vector<float> logits(generation_config.n_vocab);
+
+ int sqlen = 1;
+ if (new_prompt) {
+ sqlen = input_ids.size();
+ }
+ if (model_type == LLaMA_INT4) {
+ Int4LlamaForCausalLM *model = static_cast<Int4LlamaForCausalLM *>(model_ptr);
+ struct Int4LlamaForCausalLM_output model_output;
+ struct Int4LlamaForCausalLM_input model_input;
+ if (has_past_kv) {
+ Matrix3D<int> input_ids_mat(input_ids.data(), 1, 1, sqlen);
+ model_input = {input_ids_mat, past_keys, past_values};
+ } else {
+ Matrix3D<int> input_ids_mat(input_ids.data(), 1, 1, sqlen);
+ model_input = {input_ids_mat};
+ }
+ if (!new_prompt) STATS_START("Inference latency");
+ model_output = model->forward(param_path, model_input);
+ if (!new_prompt) STATS_END("Inference latency");
+ past_keys = model_output.past_keys;
+ past_values = model_output.past_values;
+ // memcpy model_output.logits[-1] to logits
+ memcpy(logits.data(), &model_output.logits.m_data[(sqlen - 1) * generation_config.n_vocab],
+ generation_config.n_vocab * sizeof(float));
+ } else if (model_type == LLaMA_FP32) {
+ Fp32LlamaForCausalLM *model = static_cast<Fp32LlamaForCausalLM *>(model_ptr);
+ struct Fp32LlamaForCausalLM_output model_output;
+ struct Fp32LlamaForCausalLM_input model_input;
+ if (has_past_kv) {
+ Matrix3D<int> input_ids_mat(input_ids.data(), 1, 1, sqlen);
+ model_input = {input_ids_mat, past_keys, past_values};
+ } else {
+ Matrix3D<int> input_ids_mat(input_ids.data(), 1, 1, sqlen);
+ model_input = {input_ids_mat};
+ }
+ if (!new_prompt) STATS_START("Inference latency");
+ model_output = model->forward(model_input);
+ if (!new_prompt) STATS_END("Inference latency");
+ past_keys = model_output.past_keys;
+ past_values = model_output.past_values;
+ // memcpy model_output.logits[-1] to logits
+ memcpy(logits.data(), &model_output.logits.m_data[(sqlen - 1) * generation_config.n_vocab],
+ generation_config.n_vocab * sizeof(float));
+ }
+ has_past_kv = true;
+
+ // Generate
+ const int n_ctx = generation_config.n_ctx;
+ const float temp = generation_config.temp;
+ const int32_t top_k = generation_config.top_k <= 0 ? generation_config.n_vocab : generation_config.top_k;
+ const float top_p = generation_config.top_p;
+ const float tfs_z = generation_config.tfs_z;
+ const float typical_p = generation_config.typical_p;
+ const int32_t repeat_last_n = generation_config.repeat_last_n < 0 ? n_ctx : generation_config.repeat_last_n;
+ const float repeat_penalty = generation_config.repeat_penalty;
+ const float alpha_presence = generation_config.presence_penalty;
+ const float alpha_frequency = generation_config.frequency_penalty;
+ const int mirostat = generation_config.mirostat;
+ const float mirostat_tau = generation_config.mirostat_tau;
+ const float mirostat_eta = generation_config.mirostat_eta;
+ const int n_vocab = generation_config.n_vocab;
+
+        std::vector<OPT_token_data> candidates;
+ candidates.reserve(n_vocab);
+ for (int token_id = 0; token_id < n_vocab; token_id++) {
+ candidates.emplace_back(OPT_token_data{token_id, logits[token_id], 0.0f});
+ }
+
+ OPT_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
+
+ // Apply penalties
+ auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
+ sample_repetition_penalty(&candidates_p, last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+ last_n_repeat, repeat_penalty);
+ sample_frequency_and_presence_penalties(&candidates_p,
+ last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+ last_n_repeat, alpha_frequency, alpha_presence);
+
+ int id = 0;
+ if (temp <= 0) {
+ id = sample_token_greedy(&candidates_p);
+ } else {
+ if (mirostat == 1) {
+ static float mirostat_mu = 2.0f * mirostat_tau;
+ const int mirostat_m = 100;
+ sample_temperature(&candidates_p, temp);
+ id =
+ sample_token_mirostat(n_vocab, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
+ } else if (mirostat == 2) {
+ static float mirostat_mu = 2.0f * mirostat_tau;
+ sample_temperature(&candidates_p, temp);
+ id = sample_token_mirostat_v2(&candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
+ } else {
+ // Temperature sampling
+ sample_top_k(&candidates_p, top_k, 1);
+ sample_tail_free(&candidates_p, tfs_z, 1);
+ sample_typical(&candidates_p, typical_p, 1);
+ sample_top_p(&candidates_p, top_p, 1);
+ sample_temperature(&candidates_p, temp);
+ id = sample_token(&candidates_p);
+ }
+ }
+
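+        // Llama-3 special tokens: 128000 = <|begin_of_text|>, 128001 = <|end_of_text|>, 128009 = <|eot_id|>.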
+ if (id == 128001) {
+ break_cnt--;
+ continue;
+ } // eos
+ else if (id == 128000)
+ continue;
+ break_cnt = 2;
+
+ bool skip = false;
+        if (id == 35075 or id == 128009) { // token = "Human" or "<|eot_id|>"
+ break_cnt = 0;
+ skip = true;
+ }
+
+ last_n_tokens.erase(last_n_tokens.begin());
+ last_n_tokens.push_back(id);
+ embd.push_back(id);
+ generate_ids.push_back(id);
+        input_ids = std::vector<int>{id};
+
+ if (interactive && !skip) {
+ // output += llama_id_to_token(vocab, id);
+ // std::cout << llama_id_to_token(vocab, id) << std::flush;
+ output += decode(&tokenizer, id);
+ std::cout << decode(&tokenizer, id) << std::flush;
+
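+            // In voicechat mode, strip characters that read poorly aloud and speak each
+            // completed clause on a detached background thread while generation continues.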
+ if (voicechat) {
+ // Remove quotes
+ output.erase(std::remove(output.begin(), output.end(), '\"'), output.end());
+ // Remove hashtags
+ output.erase(std::remove(output.begin(), output.end(), '#'), output.end());
+ // Remove dashes
+ std::replace(output.begin(), output.end(), '-', ' ');
+
+ size_t lastPos;
+                // starts earlier but slows down dictation
+ bool ended = false;
+ if (output.find(", ") != std::string::npos){
+ lastPos = output.rfind(',');
+ ended = true;
+ }
+ if (output.find("\n") != std::string::npos){
+ lastPos = output.rfind('\n');
+ ended = true;
+ }
+ else if (output.find(". ") != std::string::npos){
+ lastPos = output.rfind('.');
+ ended = true;
+ }
+ else if (output.find("! ") != std::string::npos){
+ lastPos = output.rfind('!');
+ ended = true;
+ }
+ else if (output.find("? ") != std::string::npos){
+ lastPos = output.rfind('?');
+ ended = true;
+
+ }
+ else if (output.find(": ") != std::string::npos){
+ lastPos = output.rfind(':');
+ ended = true;
+ }
+ if (ended){
+ // Extract sentence 1 (up to and including the last period)
+ std::string output_copy = output.substr(0, lastPos + 1);
+ // Extract beginning of sentence 2 (excluding the space after the last period)
+ output = output.substr(lastPos + 1); // Skip the last period and space
+ std::thread sayThread(sayInBackground, output_copy);
+ sayThread.detach();
+ }
+ }
+ }
+
+ new_prompt = false;
+ --n_remain;
+ }
+ if (voicechat && interactive){
+ sayInBackground(output);
+ }
+
+ if (interactive) std::cout << std::endl;
+
+ // Set prompt color
+ set_print_yellow();
+ Profiler::getInstance().report_internal();
+ Profiler::getInstance().reset();
+ // Reset color
+ set_print_reset();
+
+ // Free the tokenizer
+ free_tokenizer(&tokenizer);
+
+ return output;
+}
+
+
+// Adapted from llama3.c: https://github.com/jameswdelancey/llama3.c
+// ----------------------------------------------------------------------------
+// The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens
+int compare_tokens(const void *a, const void *b) {
+ return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
+}
+
+void build_tokenizer(Tokenizer* t, const char* tokenizer_path, int vocab_size) {
+ // i should have written the vocab_size into the tokenizer file... sigh
+ t->vocab_size = vocab_size;
+ // malloc space to hold the scores and the strings
+ t->vocab = (char**)malloc(vocab_size * sizeof(char*));
+ t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
+ t->sorted_vocab = NULL; // initialized lazily
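+    // byte_pieces[2*i .. 2*i+1] holds a one-byte C string for raw byte i; decode() returns it for <0xXX> tokens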
+ for (int i = 0; i < 256; i++) {
+ t->byte_pieces[i * 2] = (unsigned char)i;
+ t->byte_pieces[i * 2 + 1] = '\0';
+ }
+ // read in the file
+ FILE *file = fopen(tokenizer_path, "rb");
+ if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); }
+ if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
+ int len;
+ for (int i = 0; i < vocab_size; i++) {
+ if (fread(t->vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE);}
+ if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
+ t->vocab[i] = (char *)malloc(len + 1);
+ if (fread(t->vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
+ t->vocab[i][len] = '\0'; // add the string terminating token
+ }
+ fclose(file);
+}
+
+void free_tokenizer(Tokenizer* t) {
+ for (int i = 0; i < t->vocab_size; i++) { free(t->vocab[i]); }
+ free(t->vocab);
+ free(t->vocab_scores);
+    delete[] t->sorted_vocab;  // allocated with new[] in encode(); delete[] of NULL is a no-op
+}
+
+char* decode(Tokenizer* t, int token) {
+ char *piece = t->vocab[token];
+
+ // careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
+ // parse this and convert and return the actual byte
+ unsigned char byte_val;
+ if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
+ piece = (char*)t->byte_pieces + byte_val * 2;
+ }
+ return piece;
+}
+
+int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
+ // efficiently find the perfect match for str in vocab, return its index or -1 if not found
+ TokenIndex tok = { .str = str }; // acts as the key to search for
+ TokenIndex *res = (TokenIndex *) bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
+ return res != NULL ? res->id : -1;
+}
+
+void encode(Tokenizer* t, const char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
+ // encode the string text (input) into an upper-bound preallocated tokens[] array
+    // bos != 0 means prepend the BOS token (=128000), eos != 0 means append the EOS token (=128001)
+ if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); }
+
+ if (t->sorted_vocab == NULL) {
+ // lazily malloc and sort the vocabulary
+ // t->sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex));
+ t->sorted_vocab = new TokenIndex[t->vocab_size];
+ for (int i = 0; i < t->vocab_size; i++) {
+ t->sorted_vocab[i].str = t->vocab[i];
+ t->sorted_vocab[i].id = i;
+ }
+ qsort(t->sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
+ }
+
+ // create a temporary buffer that will store merge candidates of always two consecutive tokens
+ // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_length is 1)
+ // char* str_buffer = malloc((t->max_token_length*2 +1 +2) * sizeof(char));
+ char* str_buffer = new char[(t->max_token_length*2 +1 +2)];
+ size_t str_len = 0;
+
+ // start at 0 tokens
+ *n_tokens = 0;
+
+ // add optional BOS (=128000) token, if desired
+ if (bos) tokens[(*n_tokens)++] = 128000;
+
+    // note: unlike the SentencePiece-based llama2.c tokenizer, no dummy prefix token is prepended here
+
+ // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
+ // Code point ↔ UTF-8 conversion
+ // First code point Last code point Byte 1 Byte 2 Byte 3 Byte 4
+ // U+0000 U+007F 0xxxxxxx
+ // U+0080 U+07FF 110xxxxx 10xxxxxx
+ // U+0800 U+FFFF 1110xxxx 10xxxxxx 10xxxxxx
+ // U+10000 U+10FFFF 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
+
+ // process the raw (UTF-8) byte sequence of the input string
+ for (const char *c = text; *c != '\0'; c++) {
+
+ // reset buffer if the current byte is ASCII or a leading byte
+ // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the rest
+ // 0x80 is 10000000
+ // in UTF-8, all continuation bytes start with "10" in first two bits
+ // so in English this is: "if this byte is not a continuation byte"
+ if ((*c & 0xC0) != 0x80) {
+ // this byte must be either a leading byte (11...) or an ASCII char (0x...)
+ // => reset our location, as we're starting a new UTF-8 codepoint
+ str_len = 0;
+ }
+
+ // append the current byte to the buffer
+ str_buffer[str_len++] = *c; // ++ is post-increment, incremented after this line
+ str_buffer[str_len] = '\0';
+
+ // while the next character is a continuation byte, continue appending
+        // but if there are too many of them, just stop to avoid overrunning the str_buffer size.
+ if ((*(c+1) & 0xC0) == 0x80 && str_len < 4) {
+ continue;
+ }
+
+ // ok c+1 is not a continuation byte, so we've read in a full codepoint
+ int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
+
+ if (id != -1) {
+ // we found this codepoint in vocab, add it as a token
+ tokens[(*n_tokens)++] = id;
+ } else {
+ // byte_fallback encoding: just encode each byte as a token
+            // +3 is here because the first 3 vocab elements are <unk>, <s>, </s>
+ // so the individual bytes only start at index 3
+ for (int i=0; i < str_len; i++) {
+ tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3;
+ }
+ }
+ str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
+ }
+
+ // merge the best consecutive pair or triple each iteration, according to the scores in vocab_scores
+ while (1) {
+ float best_score = -1e10;
+ int best_id = -1;
+ int best_idx = -1;
+ int best_len = 2; // length of the best merge sequence (2 for pair, 3 for triple)
+
+ // first, try to find the best pair to merge
+ for (int i = 0; i < (*n_tokens - 1); i++) {
+ // check if we can merge the pair (tokens[i], tokens[i+1])
+ snprintf(str_buffer, t->max_token_length*2 +1 +2 +1, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
+ int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
+ if (id != -1 && t->vocab_scores[id] > best_score) {
+ // this merge pair exists in vocab! record its score and position
+ best_score = t->vocab_scores[id];
+ best_id = id;
+ best_idx = i;
+ }
+ }
+
+ // if no pair was found, try to find the best triple to merge
+ if (best_idx == -1) {
+ for (int i = 0; i < (*n_tokens - 2); i++) {
+ // check if we can merge the triple (tokens[i], tokens[i+1], tokens[i+2])
+ snprintf(str_buffer, t->max_token_length*2 +1 +2 +1, "%s%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]], t->vocab[tokens[i+2]]);
+ int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
+ if (id != -1 && t->vocab_scores[id] > best_score) {
+ // this merge triple exists in vocab! record its score and position
+ best_score = t->vocab_scores[id];
+ best_id = id;
+ best_idx = i;
+ best_len = 3;
+ }
+ }
+ }
+
+ if (best_idx == -1) {
+ break; // we couldn't find any more pairs or triples to merge, so we're done
+ }
+
+ // merge the consecutive pair or triple (best_idx, best_idx+1[, best_idx+2]) into new token best_id
+ tokens[best_idx] = best_id;
+ // delete token(s) at position best_idx+1 (and optionally best_idx+2), shift the entire sequence back
+ for (int i = best_idx + 1; i < (*n_tokens - best_len + 1); i++) {
+ tokens[i] = tokens[i + best_len - 1];
+ }
+ (*n_tokens) -= (best_len - 1); // token length decreased by the number of merged tokens minus one
+ }
+
+ // add optional EOS (=128001) token, if desired
+ if (eos) tokens[(*n_tokens)++] = 128001;
+
+    delete[] str_buffer;  // allocated with new[] above
+}
diff --git a/llm/src/ops/RotaryPosEmb.cc b/llm/src/ops/RotaryPosEmb.cc
index 0269f1df..c9c8af64 100644
--- a/llm/src/ops/RotaryPosEmb.cc
+++ b/llm/src/ops/RotaryPosEmb.cc
@@ -7,6 +7,7 @@ float q_buf[4096], k_buf[4096];
 void RotaryPosEmb::forward(Matrix3D<float> &query, Matrix3D<float> &key, int start_idx, int len) {
PROFILE_START(profile_name);
int num_heads = query.m_dim_x;
+ int num_kv_heads = key.m_dim_x;
int head_embed = cos.m_dim_z;
int max_sqlen = cos.m_dim_y;
@@ -25,21 +26,40 @@ void RotaryPosEmb::forward(Matrix3D<float> &query, Matrix3D<float> &key, int sta
// rotate_half: torch.cat((-x2, x1), dim=-1)
int half = head_embed / 2;
+ // Query
for (int b = 0; b < num_heads; b++) {
for (int i = 0; i < len; i++) {
// first half
for (int j = 0; j < half; j++) {
q_buf[j] = -1 * query(b, i, j + half);
- k_buf[j] = -1 * key(b, i, j + half);
+ // k_buf[j] = -1 * key(b, i, j + half);
}
// second half
for (int j = half; j < head_embed; j++) {
q_buf[j] = query(b, i, j - half);
- k_buf[j] = key(b, i, j - half);
+ // k_buf[j] = key(b, i, j - half);
}
for (int j = 0; j < head_embed; j++) {
query(b, i, j) = ((query(b, i, j) * cos(0, i + start_idx, j)) + (q_buf[j] * sin(0, i + start_idx, j)));
+ // key(b, i, j) = ((key(b, i, j) * cos(0, i + start_idx, j)) + (k_buf[j] * sin(0, i + start_idx, j)));
+ }
+ }
+ }
+
+    // Key heads (rotated in a separate loop: with grouped-query attention, num_kv_heads can differ from num_heads)
+ for (int b = 0; b < num_kv_heads; b++) {
+ for (int i = 0; i < len; i++) {
+ // first half
+ for (int j = 0; j < half; j++) {
+ k_buf[j] = -1 * key(b, i, j + half);
+ }
+ // second half
+ for (int j = half; j < head_embed; j++) {
+ k_buf[j] = key(b, i, j - half);
+ }
+
+ for (int j = 0; j < head_embed; j++) {
key(b, i, j) = ((key(b, i, j) * cos(0, i + start_idx, j)) + (k_buf[j] * sin(0, i + start_idx, j)));
}
}
diff --git a/llm/tools/copy_rotary_emb.sh b/llm/tools/copy_rotary_emb.sh
new file mode 100755
index 00000000..9b8d6e21
--- /dev/null
+++ b/llm/tools/copy_rotary_emb.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+
+# Copy from layer 0 to layer 31
+for i in {0..31}; do
+ cp -r INT4/models/CodeLLaMA_7B_Instruct/decoder/layer${i}/self_attn/rotary_emb/* INT4/models/Mistral_7B/decoder/layer${i}/self_attn/rotary_emb/
+done
\ No newline at end of file
diff --git a/llm/tools/download_model.py b/llm/tools/download_model.py
index 23f4a015..d31cae69 100644
--- a/llm/tools/download_model.py
+++ b/llm/tools/download_model.py
@@ -71,7 +71,15 @@
"md5sum": "e3e9301866f47ab84817b46467ac49f6",
},
"Mistral_7B_v0.2_Instruct_fp32": {
- "url": "",
+ "url": "https://www.dropbox.com/scl/fi/6c8igzbega3xthqyjzguc/Mistral_7B_v0.2_Instruct.zip?rlkey=88ri16xi58cjk00l0d32luaky&dl=1",
+ "md5sum": "8daa04f2af5f0470c66eb45615ab07e2",
+ },
+ "LLaMA_3_8B_Instruct_fp32": {
+ "url": "https://www.dropbox.com/scl/fi/rj1n5ozget03qdp5t90oz/LLaMA_3_8B_Instruct.zip?rlkey=ugzbipwolkf6wzm50732e1ct5&dl=1",
+ "md5sum": "ae710f37a74f98d0d47da085672179b5",
+ },
+ "VILA1.5_8B_fp32": {
+ "url": "", # noqa: E501
"md5sum": "",
},
}
@@ -130,9 +138,17 @@
"url": "https://www.dropbox.com/scl/fi/fe4dkrnzc25bt166w6bby/StarCoder_15.5B.zip?rlkey=ml1x96uep2k03z78ci7s1c0yb&dl=1",
"md5sum": "0f16236c0aec0b32b553248cc78b8caf",
},
- "Misitral_7B_v0.2_Instruct_awq_int4": {
+ "Mistral_7B_v0.2_Instruct_awq_int4": {
"url": "https://www.dropbox.com/scl/fi/ssr6bn9a6l9d4havu04om/Mistral_7B_v0.2_Instruct.zip?rlkey=73yqj6pw300o3izwr43etjqkr&dl=1",
- "md5sum": "ee96bcdee3d09046719f7d31d7f023f4",
+ "md5sum": "ac897d408a702ae79252bc79bfbbb699",
+ },
+ "LLaMA_3_8B_Instruct_awq_int4": {
+ "url": "https://www.dropbox.com/scl/fi/zo9e82pnkxjbez3waic1w/LLaMA_3_8B_Instruct.zip?rlkey=nabq14qjzeaw8el7y5lh2oj8u&dl=1",
+ "md5sum": "8c44a5d7cb2a0406f8f1cbb785ed7e17",
+ },
+ "VILA1.5_8B_awq_int4": {
+ "url": "https://www.dropbox.com/scl/fi/mnodsv9ky44zxopp3nb4p/VILA1.5_8B.zip?rlkey=ppzd1av22zon53ae6ey01zqkh&dl=1", # noqa: E501
+ "md5sum": "9e9ab4e30f9fc7de69fadb3aae511456",
},
},
"QM_x86": {
@@ -190,7 +206,15 @@
},
"Mistral_7B_v0.2_Instruct_awq_int4": {
"url": "https://www.dropbox.com/scl/fi/2f7djt8z8lhkd60velfb3/Mistral_7B_v0.2_Instruct.zip?rlkey=gga6mh8trxf6durck4y4cyihe&dl=1",
- "md5sum": "22e8692d7481807b4151f28c54f112da",
+ "md5sum": "66f24d7ca1e12f573e172d608536f997",
+ },
+ "LLaMA_3_8B_Instruct_awq_int4": {
+ "url": "https://www.dropbox.com/scl/fi/h68a5isipths6a1e1eutg/LLaMA_3_8B_Instruct.zip?rlkey=63bk1uixu5q97l26u4q1mg0s3&dl=1",
+ "md5sum": "8540fec0fefa44e13e81748ff8edb231",
+ },
+ "VILA1.5_8B_awq_int4": {
+ "url": "https://www.dropbox.com/scl/fi/088d0vxu29jt0jmgw4es8/VILA1.5_8B.zip?rlkey=uhz1yoiovrckv4o48jmrl6s6i&dl=1", # noqa: E501
+ "md5sum": "1c0574fa1d4aa81616a655bc3436479c",
},
},
"QM_CUDA": {
diff --git a/llm/tools/export_model.sh b/llm/tools/export_model.sh
new file mode 100755
index 00000000..6accfe15
--- /dev/null
+++ b/llm/tools/export_model.sh
@@ -0,0 +1,60 @@
+#!/bin/bash
+
+# # E.g., Quantize and export Mistral-7B model
+# python tools/mistral_exporter.py --model ../../llm-awq-mistral/quant_cache/mistral-7b-w4-g32-awq-v2.pt --output models/Mistral_7B
+# python tools/rotary_emb_exporter.py
+# # For x86
+# python tools/model_quantizer.py --model_path models/Mistral_7B --method QM_x86
+# mkdir Mistral_7B_for_x86
+# mkdir Mistral_7B_for_x86/INT4
+# mkdir Mistral_7B_for_x86/INT4/models
+# mv INT4/models/Mistral_7B Mistral_7B_for_x86/INT4/models
+# cd Mistral_7B_for_x86/
+# zip -r Mistral_7B_v0.2_Instruct.zip INT4
+# cd ..
+# # For ARM
+# python tools/model_quantizer.py --model_path models/Mistral_7B --method QM_ARM
+# mkdir Mistral_7B_for_ARM
+# mkdir Mistral_7B_for_ARM/INT4
+# mkdir Mistral_7B_for_ARM/INT4/models
+# mv INT4/models/Mistral_7B Mistral_7B_for_ARM/INT4/models
+# cd Mistral_7B_for_ARM/
+# zip -r Mistral_7B_v0.2_Instruct.zip INT4
+# cd ..
+# # fp32
+# mkdir Mistral_7B_FP32
+# mkdir Mistral_7B_FP32/models
+# mv models/Mistral_7B Mistral_7B_FP32/models
+# cd Mistral_7B_FP32/
+# zip -r Mistral_7B_v0.2_Instruct.zip models
+# cd ..
+
+
+# E.g., Quantize and export LLaMA3-8B model
+python tools/llama3_exporter.py --model ../../llm-awq/quant_cache/llama3-8b-w4-g32-awq-v2.pt --output models/LLaMA_3_8B_Instruct
+python tools/rotary_emb_exporter.py
+# For ARM
+python tools/model_quantizer.py --model_path models/LLaMA_3_8B_Instruct --method QM_ARM
+mkdir LLaMA_3_8B_Instruct_for_ARM
+mkdir LLaMA_3_8B_Instruct_for_ARM/INT4
+mkdir LLaMA_3_8B_Instruct_for_ARM/INT4/models
+mv INT4/models/LLaMA_3_8B_Instruct LLaMA_3_8B_Instruct_for_ARM/INT4/models
+cd LLaMA_3_8B_Instruct_for_ARM/
+zip -r LLaMA_3_8B_Instruct.zip INT4
+cd ..
+# For x86
+python tools/model_quantizer.py --model_path models/LLaMA_3_8B_Instruct --method QM_x86
+mkdir LLaMA_3_8B_Instruct_for_x86
+mkdir LLaMA_3_8B_Instruct_for_x86/INT4
+mkdir LLaMA_3_8B_Instruct_for_x86/INT4/models
+mv INT4/models/LLaMA_3_8B_Instruct LLaMA_3_8B_Instruct_for_x86/INT4/models
+cd LLaMA_3_8B_Instruct_for_x86/
+zip -r LLaMA_3_8B_Instruct.zip INT4
+cd ..
+# fp32
+mkdir LLaMA_3_8B_Instruct_FP32
+mkdir LLaMA_3_8B_Instruct_FP32/models
+mv models/LLaMA_3_8B_Instruct LLaMA_3_8B_Instruct_FP32/models
+cd LLaMA_3_8B_Instruct_FP32/
+zip -r LLaMA_3_8B_Instruct.zip models
+cd ..
diff --git a/llm/tools/llama3_exporter.py b/llm/tools/llama3_exporter.py
new file mode 100644
index 00000000..0c49baca
--- /dev/null
+++ b/llm/tools/llama3_exporter.py
@@ -0,0 +1,173 @@
+"""Implementation of exporting LLaMA-3 PyTorch model to TinyChatEngine format.
+
+Usage:
+ python llama3_exporter.py