diff --git a/cactus/engine/engine.h b/cactus/engine/engine.h
index 77594ed6..ddc4bfff 100644
--- a/cactus/engine/engine.h
+++ b/cactus/engine/engine.h
@@ -85,7 +85,7 @@ struct Config {
     float max_pixels_tolerance = 2.0f;
     bool do_image_splitting = true;

-    enum class ModelType {QWEN = 0, GEMMA = 1, SMOL = 2, NOMIC = 3, LFM2 = 5, SIGLIP2 = 6, WHISPER = 7};
+    enum class ModelType {QWEN = 0, GEMMA = 1, SMOL = 2, NOMIC = 3, LFM2 = 5, SIGLIP2 = 6, WHISPER = 7, PHI3 = 8};
     ModelType model_type = ModelType::QWEN;

     enum class ModelVariant {DEFAULT = 0, VLM = 1, EXTRACT = 2, RAG = 3};
@@ -156,7 +156,7 @@ class Tokenizer {
     void set_corpus_dir(const std::string& dir) { corpus_dir_ = dir; }

 protected:
-    enum class ModelType { UNKNOWN, QWEN, GEMMA, LFM2, SMOL, BERT, WHISPER};
+    enum class ModelType { UNKNOWN, QWEN, GEMMA, LFM2, SMOL, BERT, WHISPER, PHI3};
     ModelType model_type_ = ModelType::UNKNOWN;
     enum class ModelVariant { DEFAULT, VLM, EXTRACT, RAG};
     ModelVariant model_variant_ = ModelVariant::DEFAULT;
@@ -174,6 +174,7 @@ class Tokenizer {
     std::string format_lfm2_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
     std::string format_lfm2_vl_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
     std::string format_smol_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
+    std::string format_phi3_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const;
 };

 class BPETokenizer : public Tokenizer {
diff --git a/cactus/engine/engine_model.cpp b/cactus/engine/engine_model.cpp
index f28d424b..fb62b189 100644
--- a/cactus/engine/engine_model.cpp
+++ b/cactus/engine/engine_model.cpp
@@ -381,6 +381,7 @@ bool Config::from_json(const std::string& config_path) {
         if (value == "gemma" || value == "GEMMA") model_type = ModelType::GEMMA;
         else if (value == "lfm2" || value == "LFM2") model_type = ModelType::LFM2;
         else if (value == "smol" || value == "SMOL" || value == "Smol") model_type = ModelType::SMOL;
+        else if (value == "phi3" || value == "PHI3" || value == "Phi3") model_type = ModelType::PHI3;
         else if (value == "bert" || value == "BERT") model_type = ModelType::NOMIC;
         else if (value == "whisper" || value == "WHISPER") model_type = ModelType::WHISPER;
         else model_type = ModelType::QWEN;
@@ -477,6 +478,8 @@ std::unique_ptr<Model> create_model(const std::string& model_folder) {
             return std::make_unique<Lfm2Model>(config);
         case Config::ModelType::SMOL:
             return std::make_unique<SmolModel>(config);
+        case Config::ModelType::PHI3:
+            return std::make_unique<Phi3Model>(config);
         case Config::ModelType::NOMIC:
             return std::make_unique<NomicModel>(config);
         case Config::ModelType::WHISPER:
diff --git a/cactus/engine/engine_sp.cpp b/cactus/engine/engine_sp.cpp
index 7f1e2bbf..431382a8 100644
--- a/cactus/engine/engine_sp.cpp
+++ b/cactus/engine/engine_sp.cpp
@@ -215,7 +215,8 @@ std::string SPTokenizer::preprocess_text(const std::string& text) const {
     for (size_t i = text.find_first_not_of(" "); i < text.length(); i++) {
         char c = text[i];
-        if (c == ' ') {
+        // Phi-3 treats newlines like spaces in SentencePiece encoding
+        if (c == ' ' || (c == '\n' && model_type_ == ModelType::PHI3)) {
             processed += "▁";
         } else {
             processed += c;
@@ -383,6 +384,14 @@ std::vector<std::string> SPTokenizer::split_with_special_tokens(const std::strin
             }
             result.push_back(best_special_token);
             start = best_match_pos + best_match_len;
+            // Phi-3 specific: rstrip whitespace after special tokens (except <|endoftext|>)
+            // This matches `rstrip=True` behavior for Phi-3 chat tokens
+            if (model_type_ == ModelType::PHI3 && best_special_token != "<|endoftext|>") {
+                while (start < text.size() && (text[start] == ' ' || text[start] == '\n' ||
+                       text[start] == '\t' || text[start] == '\r')) {
+                    start++;
+                }
+            }
         } else {
             if (start < text.size()) {
                 result.push_back(text.substr(start));
diff --git a/cactus/engine/engine_tokenizer.cpp b/cactus/engine/engine_tokenizer.cpp
index 6b6a9305..e0e48e95 100644
--- a/cactus/engine/engine_tokenizer.cpp
+++ b/cactus/engine/engine_tokenizer.cpp
@@ -37,6 +37,9 @@ void Tokenizer::detect_model_type(const std::string& config_path) {
         } else if (line.find("whisper") != std::string::npos) {
             model_type_ = ModelType::WHISPER;
             break;
+        } else if (line.find("phi3") != std::string::npos || line.find("phi-3") != std::string::npos) {
+            model_type_ = ModelType::PHI3;
+            break;
         } else {
             model_type_ = ModelType::UNKNOWN;
         }
@@ -93,6 +96,8 @@ std::string Tokenizer::format_chat_prompt(const std::vector<ChatMessage>& messag
             return format_lfm2_style(messages, add_generation_prompt, tools_json);
         case ModelType::SMOL:
             return format_smol_style(messages, add_generation_prompt, tools_json);
+        case ModelType::PHI3:
+            return format_phi3_style(messages, add_generation_prompt, tools_json);
         default:
             return format_qwen_style(messages, add_generation_prompt, tools_json);
     }
@@ -413,6 +418,32 @@ std::string Tokenizer::format_smol_style(const std::vector<ChatMessage>& message
     return result;
 }

+std::string Tokenizer::format_phi3_style(const std::vector<ChatMessage>& messages, bool add_generation_prompt, const std::string& tools_json) const {
+    if (!tools_json.empty()) {
+        return "ERROR: Tool calls are currently not supported for Phi-3 models";
+    }
+
+    std::string result;
+
+    for (const auto& msg : messages) {
+        if (msg.role == "system") {
+            result += "<|system|>\n";
+        } else if (msg.role == "user") {
+            result += "<|user|>\n";
+        } else if (msg.role == "assistant") {
+            result += "<|assistant|>\n";
+        }
+        result += msg.content;
+        result += "<|end|>\n";
+    }
+
+    if (add_generation_prompt) {
+        result += "<|assistant|>\n";
+    }
+
+    return result;
+}
+
 } // namespace engine
 } // namespace cactus
\ No newline at end of file
diff --git a/cactus/models/model.h b/cactus/models/model.h
index b1865d91..78db6acf 100644
--- a/cactus/models/model.h
+++ b/cactus/models/model.h
@@ -138,6 +138,47 @@ class SmolModel : public Model{
 };

+class Phi3Model : public Model {
+public:
+    Phi3Model();
+    explicit Phi3Model(const Config& config);
+    ~Phi3Model() override = default;
+
+protected:
+    size_t build_attention(CactusGraph* gb, size_t normalized_input, uint32_t layer_idx,
+                           ComputeBackend backend, bool use_cache = false, size_t position_offset = 0) override;
+
+    size_t build_mlp(CactusGraph* gb, size_t normalized_h, uint32_t layer_idx,
+                     ComputeBackend backend) const override;
+
+    size_t build_transformer_block(CactusGraph* gb, size_t hidden, uint32_t layer_idx,
+                                   ComputeBackend backend, bool use_cache = false, size_t position_offset = 0) override;
+
+    size_t forward(const std::vector<uint32_t>& tokens, bool use_cache = false) override;
+    void load_weights_to_graph(CactusGraph* gb) override;
+
+private:
+    struct WeightNodeIDs {
+        size_t output_weight;
+        size_t output_norm_weight;
+
+        struct LayerWeights {
+            size_t attn_q_weight;
+            size_t attn_k_weight;
+            size_t attn_v_weight;
+            size_t attn_output_weight;
+            size_t input_layernorm_weight;
+            size_t ffn_gate_weight;
+            size_t ffn_up_weight;
+            size_t ffn_down_weight;
+            size_t post_attention_layernorm_weight;
+        };
+
+        std::vector<LayerWeights> layers;
+    } weight_nodes_;
+};
+
+
 class Siglip2VisionModel : public Model {
     friend class Lfm2VlModel;

diff --git a/cactus/models/model_phi.cpp b/cactus/models/model_phi.cpp
new file mode 100644
index 00000000..bf8c1709
--- /dev/null
+++ b/cactus/models/model_phi.cpp
@@ -0,0 +1,165 @@
+#include "model.h"
+#include "../graph/graph.h"
+#include <set>
+#include <stdexcept>
+#include <vector>
+
+namespace cactus {
+namespace engine {
+
+Phi3Model::Phi3Model() : Model() {}
+
+Phi3Model::Phi3Model(const Config& config) : Model(config) {
+    weight_nodes_.layers.resize(config.num_layers);
+}
+
+void Phi3Model::load_weights_to_graph(CactusGraph* gb) {
+    embedding_node_id_ = gb->mmap_embeddings(embedding_file_path_);
+    weight_nodes_.output_norm_weight = gb->mmap_weights(model_folder_path_ + "/output_norm.weights");
+
+    if (config_.tie_word_embeddings) {
+        weight_nodes_.output_weight = embedding_node_id_;
+        output_weight_node_id_ = embedding_node_id_;
+    } else {
+        weight_nodes_.output_weight = gb->mmap_weights(model_folder_path_ + "/output_weight.weights");
+        output_weight_node_id_ = weight_nodes_.output_weight;
+    }
+
+    for (uint32_t i = 0; i < config_.num_layers; i++) {
+        auto& layer = weight_nodes_.layers[i];
+        std::string layer_prefix = model_folder_path_ + "/layer_" + std::to_string(i) + "_";
+        layer.attn_q_weight = gb->mmap_weights(layer_prefix + "attn_q.weights");
+        layer.attn_k_weight = gb->mmap_weights(layer_prefix + "attn_k.weights");
+        layer.attn_v_weight = gb->mmap_weights(layer_prefix + "attn_v.weights");
+        layer.attn_output_weight = gb->mmap_weights(layer_prefix + "attn_output.weights");
+        layer.input_layernorm_weight = gb->mmap_weights(layer_prefix + "input_norm.weights");
+        layer.ffn_gate_weight = gb->mmap_weights(layer_prefix + "ffn_gate.weights");
+        layer.ffn_up_weight = gb->mmap_weights(layer_prefix + "ffn_up.weights");
+        layer.ffn_down_weight = gb->mmap_weights(layer_prefix + "ffn_down.weights");
+        layer.post_attention_layernorm_weight = gb->mmap_weights(layer_prefix + "post_attn_norm.weights");
+    }
+}
+
+size_t Phi3Model::build_attention(CactusGraph* gb, size_t normalized_input, uint32_t layer_idx,
+                                  ComputeBackend backend, bool use_cache, size_t position_offset) {
+    const auto& layer = weight_nodes_.layers[layer_idx];
+
+    auto q_proj = gb->matmul(normalized_input, layer.attn_q_weight, true, backend);
+    auto k_proj = gb->matmul(normalized_input, layer.attn_k_weight, true, backend);
+    auto v_proj = gb->matmul(normalized_input, layer.attn_v_weight, true, backend);
+
+    const auto& q_shape = gb->get_output_buffer(q_proj).shape;
+    size_t seq_len = q_shape[0];
+
+    auto q_proj_4d = gb->reshape(q_proj, {1, seq_len, config_.attention_heads, config_.attention_head_dim});
+    auto k_proj_4d = gb->reshape(k_proj, {1, seq_len, config_.attention_kv_heads, config_.attention_head_dim});
+    auto v_proj_4d = gb->reshape(v_proj, {1, seq_len, config_.attention_kv_heads, config_.attention_head_dim});
+
+    if (config_.rope_theta > 0) {
+        q_proj_4d = gb->rope(q_proj_4d, config_.rope_theta, position_offset);
+        k_proj_4d = gb->rope(k_proj_4d, config_.rope_theta, position_offset);
+    }
+
+    size_t final_k = k_proj_4d;
+    size_t final_v = v_proj_4d;
+
+    if (use_cache && !kv_cache_.is_empty()) {
+        auto k_view = kv_cache_.get_key_view(layer_idx);
+        auto v_view = kv_cache_.get_value_view(layer_idx);
+
+        if (k_view.ptr2 == nullptr && v_view.ptr2 == nullptr) {
+            size_t cache_k_node = gb->input({1, kv_cache_.current_seq_len, config_.attention_kv_heads, config_.attention_head_dim}, kv_cache_.precision);
+            size_t cache_v_node = gb->input({1, kv_cache_.current_seq_len, config_.attention_kv_heads, config_.attention_head_dim}, kv_cache_.precision);
+
+            gb->set_input(cache_k_node, k_view.ptr1, kv_cache_.precision);
+            gb->set_input(cache_v_node, v_view.ptr1, kv_cache_.precision);
+
+            final_k = gb->concat(cache_k_node, k_proj_4d, 1);
+            final_v = gb->concat(cache_v_node, v_proj_4d, 1);
+        } else {
+            size_t cache_k_node = gb->input({1, kv_cache_.current_seq_len, config_.attention_kv_heads, config_.attention_head_dim}, kv_cache_.precision);
+            size_t cache_v_node = gb->input({1, kv_cache_.current_seq_len, config_.attention_kv_heads, config_.attention_head_dim}, kv_cache_.precision);
+
+            gb->set_input(cache_k_node, kv_cache_.get_key_ptr(layer_idx), kv_cache_.precision);
+            gb->set_input(cache_v_node, kv_cache_.get_value_ptr(layer_idx), kv_cache_.precision);
+
+            final_k = gb->concat(cache_k_node, k_proj_4d, 1);
+            final_v = gb->concat(cache_v_node, v_proj_4d, 1);
+        }
+    }
+
+    if (use_cache) {
+        cache_k_output_nodes_[layer_idx] = final_k;
+        cache_v_output_nodes_[layer_idx] = final_v;
+    }
+
+    auto attn_output_4d = gb->attention(q_proj_4d, final_k, final_v, attention_scale_, position_offset);
+    auto attn_output = gb->reshape(attn_output_4d, {seq_len, config_.attention_head_dim * config_.attention_heads});
+    return gb->matmul(attn_output, layer.attn_output_weight, true, backend);
+}
+
+size_t Phi3Model::build_mlp(CactusGraph* gb, size_t normalized_h, uint32_t layer_idx,
+                            ComputeBackend backend) const {
+    const auto& layer = weight_nodes_.layers[layer_idx];
+    size_t gate_output = gb->matmul(normalized_h, layer.ffn_gate_weight, true, backend);
+    size_t up_output = gb->matmul(normalized_h, layer.ffn_up_weight, true, backend);
+    size_t gate_silu = gb->silu(gate_output);
+    size_t gated = gb->multiply(gate_silu, up_output);
+    return gb->matmul(gated, layer.ffn_down_weight, true, backend);
+}
+
+size_t Phi3Model::build_transformer_block(CactusGraph* gb, size_t hidden, uint32_t layer_idx,
+                                          ComputeBackend backend, bool use_cache, size_t position_offset) {
+    const auto& layer = weight_nodes_.layers[layer_idx];
+    auto normalized_input = gb->rms_norm(hidden, layer.input_layernorm_weight, config_.layer_norm_eps);
+    auto attn_output = build_attention(gb, normalized_input, layer_idx, backend, use_cache, position_offset);
+    auto after_attention = gb->add(hidden, attn_output);
+    auto normalized_after_attention = gb->rms_norm(after_attention, layer.post_attention_layernorm_weight, config_.layer_norm_eps);
+    auto mlp_output = build_mlp(gb, normalized_after_attention, layer_idx, backend);
+    return gb->add(after_attention, mlp_output);
+}
+
+size_t Phi3Model::forward(const std::vector<uint32_t>& tokens, bool use_cache) {
+    if (!initialized_ || !graph_handle_) {
+        throw std::runtime_error("Model not initialized - call init() first");
+    }
+
+    if (tokens.empty()) {
+        throw std::runtime_error("Token sequence cannot be empty");
+    }
+
+    auto* gb = static_cast<CactusGraph*>(graph_handle_);
+    gb->soft_reset();
+
+    auto seq_len = static_cast<size_t>(tokens.size());
+
+    size_t position_offset = use_cache ? kv_cache_.get_total_seq_len() : 0;
+
+    auto backend = config_.default_backend == Config::Backend::CPU ?
+                       ComputeBackend::CPU : ComputeBackend::NPU;
+
+    auto input_node_id = gb->input({seq_len}, Precision::FP32);
+    auto hidden = gb->embedding(embedding_node_id_, input_node_id);
+
+    static std::set<uint32_t> skip_layers = {};
+    for (uint32_t layer_idx = 0; layer_idx < config_.num_layers; layer_idx++) {
+        if (skip_layers.count(layer_idx)) {
+            continue;
+        }
+        hidden = build_transformer_block(gb, hidden, layer_idx, backend, use_cache, position_offset);
+    }
+
+    auto final_hidden = gb->rms_norm(hidden, weight_nodes_.output_norm_weight, config_.layer_norm_eps);
+
+    std::vector<float> input_data(seq_len);
+    for (size_t i = 0; i < seq_len; i++) {
+        input_data[i] = static_cast<float>(tokens[i]);
+    }
+    gb->set_input(input_node_id, input_data.data(), Precision::FP32);
+
+    return final_hidden;
+}
+
+}
+}
\ No newline at end of file
diff --git a/tests/test_engine.cpp b/tests/test_engine.cpp
index d8611802..65500719 100644
--- a/tests/test_engine.cpp
+++ b/tests/test_engine.cpp
@@ -10,10 +10,47 @@ const char* g_transcribe_model_path = std::getenv("CACTUS_TEST_TRANSCRIBE_MODEL"
 const char* g_audio_file_path = "../assets/test.wav";
 const char* g_whisper_prompt = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>";

-const char* g_options = R"({
-    "max_tokens": 256,
-    "stop_sequences": ["<|im_end|>", "<end_of_turn>"]
-  })";
+// Detect model type from `path/config` and return stop sequences
+std::string get_stop_sequences_for_model() {
+    if (!g_model_path) {
+        // default = Phi-3
+        return R"(["<|end|>", "<|endoftext|>"])";
+    }
+    std::string model_path_str(g_model_path);
+    std::string config_path = model_path_str + "/config.txt";
+
+    std::ifstream config_file(config_path);
+    if (config_file.is_open()) {
+        std::string line;
+        while (std::getline(config_file, line)) {
+            if (line.find("qwen") != std::string::npos || line.find("Qwen") != std::string::npos) {
+                return R"(["<|im_end|>"])";
+            } else if (line.find("gemma") != std::string::npos || line.find("Gemma") != std::string::npos) {
+                return R"(["<end_of_turn>"])";
+            } else if (line.find("phi") != std::string::npos || line.find("Phi") != std::string::npos) {
+                return R"(["<|end|>", "<|endoftext|>"])";
+            }
+        }
+        config_file.close();
+    }
+
+    if (model_path_str.find("qwen") != std::string::npos || model_path_str.find("Qwen") != std::string::npos) {
+        return R"(["<|im_end|>"])";
+    } else if (model_path_str.find("gemma") != std::string::npos || model_path_str.find("Gemma") != std::string::npos) {
+        return R"(["<end_of_turn>"])";
+    } else if (model_path_str.find("phi") != std::string::npos || model_path_str.find("Phi") != std::string::npos) {
+        return R"(["<|end|>", "<|endoftext|>"])";
+    }
+
+    return R"(["<|end|>", "<|endoftext|>"])";
+}
+
+std::string get_options_json() {
+    return R"({"max_tokens": 256, "stop_sequences": )" + get_stop_sequences_for_model() + R"(})";
+}
+
+std::string g_options_str = get_options_json();
+const char* g_options = g_options_str.c_str();

 template <typename TestFunc>
 bool run_test(const char* title, const char* messages, TestFunc test_logic,
@@ -232,16 +269,29 @@ bool test_embeddings() {
     const char* texts[] = {"My name is Henry Ndubuaku", "Your name is Henry Ndubuaku"};

     std::vector<float> emb1(2048), emb2(2048);
-    size_t dim1, dim2;
+    size_t dim1 = 0, dim2 = 0;

     Timer t1;
-    cactus_embed(model, texts[0], emb1.data(), emb1.size() * sizeof(float), &dim1);
+    int result1 = cactus_embed(model, texts[0], emb1.data(), emb1.size() * sizeof(float), &dim1);
     double time1 = t1.elapsed_ms();

+    // Check if embeddings are supported by the model
+    if (result1 < 0) {
+        std::cout << "⊘ SKIP │ embeddings │ model does not support embeddings (causal LM)" << std::endl;
+        cactus_destroy(model);
+        return true;
+    }
+
     Timer t2;
-    cactus_embed(model, texts[1], emb2.data(), emb2.size() * sizeof(float), &dim2);
+    int result2 = cactus_embed(model, texts[1], emb2.data(), emb2.size() * sizeof(float), &dim2);
     double time2 = t2.elapsed_ms();

+    if (result2 < 0 || dim1 == 0 || dim2 == 0) {
+        std::cout << "[✗] FAIL │ embeddings │ embedding extraction failed" << std::endl;
+        cactus_destroy(model);
+        return false;
+    }
+
     float dot = 0, norm1 = 0, norm2 = 0;
     for (size_t i = 0; i < dim1; ++i) {
         dot += emb1[i] * emb2[i];
diff --git a/tools/convert_hf.py b/tools/convert_hf.py
index e3f04289..d4240e7b 100644
--- a/tools/convert_hf.py
+++ b/tools/convert_hf.py
@@ -249,6 +249,8 @@ def _cfg_get(c, key, default=None):
         detected_model_type = 'lfm2'
     elif 'qwen' in model_type_str:
         detected_model_type = 'qwen'
+    elif 'phi3' in model_type_str or 'phi-3' in model_type_str:
+        detected_model_type = 'phi3'
     elif 'llama' in model_type_str:
         if('smol' in str(output_dir)):
             detected_model_type = 'smol'
@@ -601,6 +603,34 @@ def _cfg_get(c, key, default=None):
                 save_tensor_with_header(v_weight, output_dir / f'layer_{i}_attn_v.weights', precision, transpose=False, stats_tracker=quantization_stats, args=args, model_type=detected_model_type)
                 saved_tensor_full_names.add(attn_name)
                 found = True
+
+            # Fused QKV projection handling
+            qkv_name = layer_prefix + 'self_attn.qkv_proj.weight'
+            if qkv_name in state_dict:
+                qkv_weight = state_dict[qkv_name]
+                num_heads = model_config['attention_heads']
+                num_kv_heads = model_config['attention_kv_heads']
+                head_dim = model_config['attention_head_dim']
+                q_size = num_heads * head_dim
+                kv_size = num_kv_heads * head_dim
+                q_weight = qkv_weight[:q_size, :]
+                k_weight = qkv_weight[q_size:q_size + kv_size, :]
+                v_weight = qkv_weight[q_size + kv_size:, :]
+                save_tensor_with_header(q_weight, output_dir / f'layer_{i}_attn_q.weights', precision, transpose=False, stats_tracker=quantization_stats, args=args, model_type=detected_model_type)
+                save_tensor_with_header(k_weight, output_dir / f'layer_{i}_attn_k.weights', precision, transpose=False, stats_tracker=quantization_stats, args=args, model_type=detected_model_type)
+                save_tensor_with_header(v_weight, output_dir / f'layer_{i}_attn_v.weights', precision, transpose=False, stats_tracker=quantization_stats, args=args, model_type=detected_model_type)
+                saved_tensor_full_names.add(qkv_name)
+
+            # Fused gate_up projection handling
+            gate_up_name = layer_prefix + 'mlp.gate_up_proj.weight'
+            if gate_up_name in state_dict:
+                gate_up_weight = state_dict[gate_up_name]
+                intermediate_size = gate_up_weight.shape[0] // 2
+                gate_weight = gate_up_weight[:intermediate_size, :]
+                up_weight = gate_up_weight[intermediate_size:, :]
+                save_tensor_with_header(gate_weight, output_dir / f'layer_{i}_ffn_gate.weights', precision, transpose=False, stats_tracker=quantization_stats, args=args, model_type=detected_model_type)
+                save_tensor_with_header(up_weight, output_dir / f'layer_{i}_ffn_up.weights', precision, transpose=False, stats_tracker=quantization_stats, args=args, model_type=detected_model_type)
+                saved_tensor_full_names.add(gate_up_name)

     if saved_tensor_full_names != set(state_dict.keys()):
         print(f"Warning: Unsaved tensors: {set(state_dict.keys()) - saved_tensor_full_names}")
@@ -1209,6 +1239,8 @@ def write_merges_file(merges_list):
             content = token_info.get('content', '')
             token_id = int(token_id_str)
+            special_tokens[token_id] = content
+            print(f" Found special token: {content} (ID: {token_id})")

             tool_related = ['', '', '', '', '', '',
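
Illustration (not part of the diff): the convert_hf.py hunk above splits Phi-3's fused projections by rows, slicing qkv_proj.weight into Q/K/V using the head counts and splitting gate_up_proj.weight into equal gate and up halves. The sketch below reproduces only that slicing in standalone Python; the dimensions are made up for illustration and the real values come from the converted model's config.

# Minimal sketch of the fused-weight splitting; dimensions are hypothetical.
import numpy as np

num_heads, num_kv_heads, head_dim, hidden = 4, 2, 8, 32
q_size, kv_size = num_heads * head_dim, num_kv_heads * head_dim

# qkv_proj.weight is stored as [q_rows + k_rows + v_rows, hidden]
qkv = np.random.randn(q_size + 2 * kv_size, hidden).astype(np.float32)
q, k, v = qkv[:q_size], qkv[q_size:q_size + kv_size], qkv[q_size + kv_size:]
assert q.shape == (q_size, hidden)
assert k.shape == v.shape == (kv_size, hidden)

# gate_up_proj.weight stacks the gate and up halves along dim 0
intermediate_size = 16
gate_up = np.random.randn(2 * intermediate_size, hidden).astype(np.float32)
gate, up = gate_up[:intermediate_size], gate_up[intermediate_size:]
assert gate.shape == up.shape == (intermediate_size, hidden)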