Merge branch 'refs/heads/dev'
turboderp committed Jul 11, 2024
2 parents 106a9d1 + b3e07ee commit ca9aecf
Showing 59 changed files with 1,191 additions and 290 deletions.
15 changes: 14 additions & 1 deletion eval/humaneval.py
@@ -5,7 +5,7 @@
from exllamav2 import model_init
from exllamav2 import ExLlamaV2Cache, ExLlamaV2Cache_Q4, ExLlamaV2Cache_Q6, ExLlamaV2Cache_Q8
from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2DynamicJob, ExLlamaV2Sampler
import argparse, contextlib
import argparse, contextlib, subprocess
import util

# Args
@@ -20,6 +20,7 @@
parser.add_argument("--max_tokens", type = int, default = 768, help = "Max number of tokens for each completion")
parser.add_argument("-pf", "--prompt_format", type = str, help = "Instruct format to apply. Default is raw completion (for base models) ")
parser.add_argument("-v", "--verbose", action = "store_true", help = "Spam completions to console while generating")
parser.add_argument("-e", "--eval", action = "store_true", help = "Run evaluation script on output file after sampling")
model_init.add_args(parser)
args = parser.parse_args()

@@ -52,6 +53,13 @@
"<|start_header_id|>assistant<|end_header_id|>\n\n"
"Sure! Here is how you might implement the function:\n\n```python\n{{problem}} ",
" "
),
"gemma": (
"<bos><start_of_turn>user\n"
"Complete the following Python function:\n\n{{problem}}<|eot_id|>"
"<start_of_turn>model\n"
"```python\n{{problem}} ",
" "
)
}

@@ -192,3 +200,8 @@
print(f" -- Saving: {args.output}")
write_jsonl(args.output, samples)

# Optionally launch eval script

if args.eval:
    subprocess.run(["evaluate_functional_correctness", args.output])

9 changes: 6 additions & 3 deletions examples/chat.py
@@ -61,7 +61,7 @@

parser.add_argument("-ngram", "--ngram_decoding", action = "store_true", help = "Use n-gram speculative decoding")

parser.add_argument("-pt", "--print_timings", action = "store_true", help = "Output timings after each prompt")
parser.add_argument("-pt", "--print_timings", action = "store_true", help = "Output timings/stats after each prompt")
parser.add_argument("-amnesia", "--amnesia", action = "store_true", help = "Forget context after every response")

# Arrrgs
@@ -235,7 +235,9 @@ def get_tokenized_context(max_len):

# Stop conditions

generator.set_stop_conditions(prompt_format.stop_conditions(tokenizer))
sc = prompt_format.stop_conditions(tokenizer)
sc = [x for x in sc if x]
generator.set_stop_conditions(sc)

# ANSI color codes

@@ -393,8 +395,9 @@ def get_tokenized_context(max_len):
else:
sd_stats = ""

ctx_tokens = active_context.shape[-1]
print()
print(col_sysprompt + f"(Response: {response_tokens} tokens, {speed:.2f} tokens/second{sd_stats})" + col_default)
print(col_sysprompt + f"(Context: {ctx_tokens} tokens, response: {response_tokens} tokens, {speed:.2f} tokens/second{sd_stats})" + col_default)

# Optionally forget context after each response

1 change: 1 addition & 0 deletions examples/chat_prompts.py
@@ -229,6 +229,7 @@ def subs_prompt(self):
def stop_conditions(self, tokenizer):
return \
[tokenizer.eos_token_id,
tokenizer.single_id("<|im_end|>"),
"""<|im_end|>"""]

def encoding_options(self):
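The companion change in examples/chat.py above filters this list before handing it to the generator, presumably because single_id can return a falsy value (such as None) for a token string missing from a given model's vocabulary; this diff does not confirm that behaviour. A minimal sketch of the filtering idea with plain placeholder values:

```python
# A stop-condition list may mix token IDs and literal strings; entries that
# failed to resolve (represented here as None) are dropped before use.
stop_conditions = [2, None, "<|im_end|>"]  # hypothetical values
stop_conditions = [x for x in stop_conditions if x]
assert stop_conditions == [2, "<|im_end|>"]
```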
2 changes: 2 additions & 0 deletions examples/dynamic_gen.py
@@ -136,6 +136,7 @@ def main():
if use_draft_model:

draft_config = ExLlamaV2Config(draft_model_dir)
draft_config.arch_compat_overrides()
draft_model = ExLlamaV2(draft_config)

draft_cache = ExLlamaV2Cache(
@@ -155,6 +156,7 @@
# 2048, which will also be the limit of the chunk size for prefill used by the dynamic generator.

config = ExLlamaV2Config(model_dir)
config.arch_compat_overrides()
config.max_input_len = max_chunk_size
config.max_attention_size = max_chunk_size ** 2
model = ExLlamaV2(config)
1 change: 1 addition & 0 deletions examples/inference.py
@@ -7,6 +7,7 @@

model_dir = "/mnt/str/models/mistral-7b-exl2/4.0bpw"
config = ExLlamaV2Config(model_dir)
config.arch_compat_overrides()
model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, max_seq_len = 32768, lazy = True)
model.load_autosplit(cache, progress = True)
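The arch_compat_overrides() call added here recurs across the example scripts in this commit. For reference, the shared loading pattern now looks like this (a sketch assembled from the updated examples; the model path is a placeholder):

```python
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache

model_dir = "/path/to/model-exl2/4.0bpw"  # placeholder path

config = ExLlamaV2Config(model_dir)
config.arch_compat_overrides()  # new in this commit, applied before constructing the model
model = ExLlamaV2(config)

# Lazy cache allocation; the cache is sized and placed while the model is
# auto-split across the available GPUs.
cache = ExLlamaV2Cache(model, max_seq_len = 32768, lazy = True)
model.load_autosplit(cache, progress = True)
```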
1 change: 1 addition & 0 deletions examples/inference_async.py
@@ -9,6 +9,7 @@
async def main():
model_dir = "/mnt/str/models/llama3-8b-exl2/4.0bpw"
config = ExLlamaV2Config(model_dir)
config.arch_compat_overrides()
model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, lazy = True)
model.load_autosplit(cache, progress = True)
1 change: 1 addition & 0 deletions examples/inference_banned_strings.py
@@ -9,6 +9,7 @@

model_dir = "/mnt/str/models/llama3-8b-instruct-exl2/6.0bpw/"
config = ExLlamaV2Config(model_dir)
config.arch_compat_overrides()
model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, lazy = True)
model.load_autosplit(cache, progress = True)
1 change: 1 addition & 0 deletions examples/inference_cfg.py
@@ -8,6 +8,7 @@

model_dir = "/mnt/str/models/llama3-8b-instruct-exl2/4.0bpw"
config = ExLlamaV2Config(model_dir)
config.arch_compat_overrides()
model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, max_seq_len = 32768, lazy = True)
model.load_autosplit(cache, progress = True)
1 change: 1 addition & 0 deletions examples/inference_dedup.py
@@ -8,6 +8,7 @@

model_dir = "/mnt/str/models/llama3-8b-instruct-exl2/4.0bpw"
config = ExLlamaV2Config(model_dir)
config.arch_compat_overrides()
model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, max_seq_len = 8192, lazy = True)
model.load_autosplit(cache, progress = True)
1 change: 1 addition & 0 deletions examples/inference_json.py
@@ -13,6 +13,7 @@

model_dir = "/mnt/str/models/mistral-7b-exl2/4.0bpw"
config = ExLlamaV2Config(model_dir)
config.arch_compat_overrides()
model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, max_seq_len = 32768, lazy = True)
model.load_autosplit(cache, progress = True)
1 change: 1 addition & 0 deletions examples/inference_lora.py
@@ -7,6 +7,7 @@

model_dir = "/mnt/str/models/llama2-7b-exl2/5.0bpw"
config = ExLlamaV2Config(model_dir)
config.arch_compat_overrides()
model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, max_seq_len = 32768, lazy = True)
model.load_autosplit(cache, progress = True)
2 changes: 2 additions & 0 deletions examples/inference_speculative.py
@@ -12,12 +12,14 @@

draft_model_dir = "/mnt/str/models/qwen2-1.5b-instruct-exl2/4.0bpw"
draft_config = ExLlamaV2Config(draft_model_dir)
draft_config.arch_compat_overrides()
draft_model = ExLlamaV2(draft_config)
draft_cache = ExLlamaV2Cache(draft_model, max_seq_len = total_cache_tokens, lazy = True)
draft_model.load_autosplit(draft_cache, progress = True)

model_dir = "/mnt/str/models/qwen2-72b-instruct-exl2/6.0bpw"
config = ExLlamaV2Config(model_dir)
config.arch_compat_overrides()
model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, max_seq_len = total_cache_tokens, lazy = True)
model.load_autosplit(cache, progress = True)
1 change: 1 addition & 0 deletions examples/inference_stream.py
@@ -8,6 +8,7 @@

model_dir = "/mnt/str/models/mistral-7b-exl2/4.0bpw"
config = ExLlamaV2Config(model_dir)
config.arch_compat_overrides()
model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, lazy = True)
model.load_autosplit(cache, progress = True)
9 changes: 8 additions & 1 deletion examples/util.py
@@ -29,6 +29,12 @@ def format_prompt(prompt_format, sp, p):
f"{p}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
elif prompt_format == "gemma":
return (
f"<bos><start_of_turn>user\n"
f"{p}<end_of_turn>\n"
f"<start_of_turn>model\n"
)

def get_stop_conditions(prompt_format, tokenizer):
if prompt_format == "llama":
@@ -37,7 +43,8 @@ def get_stop_conditions(prompt_format, tokenizer):
return [tokenizer.single_id("<|eot_id|>")]
elif prompt_format == "granite":
return [tokenizer.eos_token_id, "\n\nQuestion:"]

elif prompt_format == "gemma":
return [tokenizer.eos_token_id, "<end_of_turn>"]

# Cached dataset loader

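A short usage sketch for the new "gemma" entries in format_prompt and get_stop_conditions (the model path is a placeholder, and constructing ExLlamaV2Tokenizer from the config is assumed from the library's other examples rather than shown in this diff):

```python
from exllamav2 import ExLlamaV2Config, ExLlamaV2Tokenizer  # tokenizer import assumed, not part of this diff
from util import format_prompt, get_stop_conditions

model_dir = "/path/to/gemma-2-exl2"  # placeholder path
config = ExLlamaV2Config(model_dir)
tokenizer = ExLlamaV2Tokenizer(config)

# The gemma template wraps only the user turn; the system-prompt argument is
# not inserted for this format.
prompt = format_prompt("gemma", "", "Write a function that reverses a string.")

# Stops on either the EOS token ID or the literal "<end_of_turn>" string.
stop_conditions = get_stop_conditions("gemma", tokenizer)
```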
108 changes: 108 additions & 0 deletions exllamav2/architecture.py
@@ -9,6 +9,12 @@
["ln_2"]]
layer_keys_yi_norms = [["ln1", "input_layernorm"],
["ln2", "post_attention_layernorm"]]
layer_keys_gemma2_norms = [["input_layernorm"],
["post_attention_layernorm"],
["pre_feedforward_layernorm"],
["post_feedforward_layernorm"]]
layer_keys_internlm2_norms = [["attention_norm"],
["ffn_norm"]]
layer_keys_llama_attn = [["self_attn.q_proj"],
["self_attn.k_proj"],
["self_attn.v_proj"],
@@ -17,6 +23,10 @@
["self_attn.c_attn", "self_attn.k_proj"],
["self_attn.c_attn", "self_attn.v_proj"],
["self_attn.o_proj"]]
layer_keys_internlm2_attn = [["self_attn.wqkv", "self_attn.q_proj"],
["self_attn.wqkv", "self_attn.k_proj"],
["self_attn.wqkv", "self_attn.v_proj"],
["self_attn.o_proj"]]
layer_keys_dbrx_attn = [["self_attn.Wqkv", "self_attn.q_proj"],
["self_attn.Wqkv", "self_attn.k_proj"],
["self_attn.Wqkv", "self_attn.v_proj"],
@@ -28,6 +38,9 @@
layer_keys_llama_mlp = [["mlp.down_proj"],
["mlp.gate_proj"],
["mlp.up_proj"]]
layer_keys_internlm2_mlp = [["feed_forward.w1"],
["feed_forward.w2"],
["feed_forward.w3"]]
layer_keys_phi3_mlp = [["mlp.down_proj"],
["mlp.gate_up_proj", "mlp.gate_proj"],
["mlp.gate_up_proj", "mlp.up_proj"]]
@@ -76,6 +89,10 @@
("$h.", "model.layers."),
("$wte.", "model.embed_tokens."),
("$wpe.", "model.wpe.")]
internlm2_keymap = [("$output.", "lm_head."),
("$model.tok_embeddings.", "model.embed_tokens."),
(".attention.", ".self_attn."),
(".wo.", ".o_proj.")]

class RopeStyle(Enum):
NONE = 0
@@ -100,6 +117,18 @@ def __init__(self, arch_string, read_config):
self.orig_weights_transposed = False
self.logit_scale_basedim = False

self.norm_key_1_post = None
self.norm_key_2_post = None

self.swa = False
self.alternating_swa = False

self.eager_attn_only = False
self.clamp_hidden_states = False
self.residual_stream_fp32 = False

self.fused_qkv_altpack = False

# Mistral

if arch_string == "MistralForCausalLM":
@@ -305,6 +334,45 @@ def __init__(self, arch_string, read_config):
self.mqa = False
self.scale_attn_weights = False

# Gemma2

if arch_string == "Gemma2ForCausalLM":
arch_recognized = True
self.layer_keys += \
layer_keys_gemma2_norms + \
layer_keys_llama_attn + \
layer_keys_llama_mlp
self.expect_keys += \
expect_keys_gemma
self.norm_eps_key = "rms_norm_eps"
self.attention_bias_qkv = False
self.attention_bias_o = False
self.mlp_bias = False
self.mlp_gate = True
self.mlp_key_gate = ".mlp.gate_proj"
self.mlp_key_up = ".mlp.up_proj"
self.mlp_key_down = ".mlp.down_proj"
self.mlp_act_func = "gelu"
self.is_moe = False
self.norm = "rmsnorm"
self.lm_head_key = "model.embed_tokens"
self.normalize_embeddings = True
self.norm_key_1 = ".input_layernorm"
self.norm_key_1_post = ".post_attention_layernorm"
self.norm_key_2 = ".pre_feedforward_layernorm"
self.norm_key_2_post = ".post_feedforward_layernorm"
self.norm_constant_bias = 1
self.parallel_decoder_blocks = False
self.requires_bos = True
self.rope_style = RopeStyle.NEOX
self.keymap = None
self.fused_qkv_key = None
self.mqa = False
self.scale_attn_weights = False
self.pre_post_layernorm = True
self.alternating_swa = True
self.residual_stream_fp32 = True

# StarCoder2

if arch_string == "Starcoder2ForCausalLM":
@@ -586,6 +654,41 @@ def __init__(self, arch_string, read_config):
self.scale_attn_weights = False
self.logit_scale_basedim = True

# InternLM2

if arch_string == "InternLM2ForCausalLM":
arch_recognized = True
self.layer_keys += \
layer_keys_internlm2_norms + \
layer_keys_internlm2_attn + \
layer_keys_internlm2_mlp
self.expect_keys += \
expect_keys_llama
self.norm_eps_key = "rms_norm_eps"
self.attention_bias_qkv = False
self.attention_bias_o = False
self.mlp_bias = False
self.mlp_gate = True
self.mlp_key_gate = ".feed_forward.w1"
self.mlp_key_up = ".feed_forward.w3"
self.mlp_key_down = ".feed_forward.w2"
self.mlp_act_func = "silu"
self.is_moe = False
self.norm = "rmsnorm"
self.lm_head_key = "lm_head"
self.normalize_embeddings = False
self.norm_key_1 = ".attention_norm"
self.norm_key_2 = ".ffn_norm"
self.norm_constant_bias = 0
self.parallel_decoder_blocks = False
self.requires_bos = False
self.rope_style = RopeStyle.NEOX
self.keymap = internlm2_keymap
self.fused_qkv_key = "wqkv"
self.fused_qkv_altpack = True
self.mqa = False
self.scale_attn_weights = False

# Llama (default + fallback)

if arch_string != "LlamaForCausalLM" and not arch_recognized:
@@ -637,6 +740,11 @@ def __init__(self, arch_string, read_config):
self.expect_keys.remove(["lm_head"])
self.lm_head_key = "model.embed_tokens"

# Sanity checks

if self.residual_stream_fp32:
    assert self.norm_key_1_post and self.norm_key_2_post, \
        "FP32 residual stream only implemented for arch with post layernorms"

def make_fused_mlp(self):
