
Commit e8a544f

Update logit mixing example, rename to CFG

1 parent b2e3982

example_logit_mixing.py renamed to example_cfg.py

Lines changed: 10 additions & 8 deletions
@@ -2,6 +2,7 @@
 from tokenizer import ExLlamaTokenizer
 from generator import ExLlamaGenerator
 import torch
+import torch.nn.functional as F
 import os, glob
 import cuda_ext
 

@@ -20,7 +21,6 @@
 
 config = ExLlamaConfig(model_config_path)      # create config from config.json
 config.model_path = model_path                 # supply path to model weights file
-config.max_input_len = 16
 
 model = ExLlama(config)                        # create ExLlama instance and load the weights
 tokenizer = ExLlamaTokenizer(tokenizer_path)   # create tokenizer from tokenizer model file
@@ -31,10 +31,10 @@
 # Configure generator
 
 generator.settings.token_repetition_penalty_max = 1.15
-generator.settings.temperature = 0.75
+generator.settings.temperature = 0.95
 generator.settings.top_k = 40
-generator.settings.top_p = 0.65
-# generator.settings.typical = 0.5
+generator.settings.top_p = 0.75
+# generator.settings.typical = 0.95
 
 # Prompts to mix
 

@@ -46,28 +46,30 @@
 
 f2 = \
 """[INST] <<SYS>>
-You are a rude and obnoxious assistant. You hate everything and everyone.
 <</SYS>>
+You are a rude and obnoxious assistant. You hate everything and everyone.
 {prompt}[/INST]"""
 
+
 prompts = \
 [
     f1.replace("{prompt}", "Tell me about Homer Simpson"),
     f2.replace("{prompt}", "Tell me about Homer Simpson"),
 ]
 
-def mixed_generation(prompts, alpha, max_new_tokens):
+def generate_cfg(prompts, alpha, max_new_tokens):
 
     ids, mask = tokenizer.encode(prompts, return_mask = True)
     generator.gen_begin(ids, mask = mask)
 
     # Sampling loop
 
-    for i in range(max_new_tokens):
+    for _ in range(max_new_tokens):
 
         logits = model.forward(generator.sequence[:, -1:], cache, input_mask = mask)
         generator.apply_rep_penalty(logits)
 
+        logits = F.log_softmax(logits, dim = -1)
         logits_mixed = (1 - alpha) * logits[0] + alpha * logits[1]
 
         sampled_token, _ = generator.sample_current(logits_mixed)
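The key change in this hunk is the F.log_softmax call before mixing: both logit rows are normalized to log-probabilities first, so the interpolation blends two proper distributions instead of raw, differently-scaled logits. A minimal standalone sketch of that step in plain PyTorch (not ExLlama's API; the tensor shapes and names here are illustrative):

# Illustrative only: mimics the mixing step above outside ExLlama.
# vocab_size and the toy tensors are made up for the demo.

import torch
import torch.nn.functional as F

def mix_logits(logits_a, logits_b, alpha):
    # Normalize each row to log-probabilities first, so the interpolation
    # combines two proper distributions rather than unnormalized logits.
    logp_a = F.log_softmax(logits_a, dim = -1)
    logp_b = F.log_softmax(logits_b, dim = -1)
    return (1 - alpha) * logp_a + alpha * logp_b

vocab_size = 8
logits_a = torch.randn(vocab_size)        # stands in for logits[0] (prompt 1)
logits_b = torch.randn(vocab_size)        # stands in for logits[1] (prompt 2)

mixed = mix_logits(logits_a, logits_b, alpha = 0.3)
probs = torch.softmax(mixed, dim = -1)    # renormalize before sampling
next_token = torch.multinomial(probs, 1)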
@@ -86,5 +88,5 @@ def mixed_generation(prompts, alpha, max_new_tokens):
     print(f"--------------------------------------")
     print(f"alpha = {alpha:.1f}")
     print(f"--------------------------------------")
-    output = mixed_generation(prompts, alpha, 200)
+    output = generate_cfg(prompts, alpha, 200)
     print(output[len(prompts[0]):].strip())
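One way to read the rename to CFG: the interpolation (1 - alpha) * logp_a + alpha * logp_b is, algebraically, the standard classifier-free guidance update with guidance weight w = -alpha applied to log-probabilities. A quick check of that equivalence (toy tensors, not from the commit):

import torch
import torch.nn.functional as F

logp_a = F.log_softmax(torch.randn(32000), dim = -1)   # e.g. the plain prompt
logp_b = F.log_softmax(torch.randn(32000), dim = -1)   # e.g. the negative prompt

alpha = -0.4                # negative alpha pushes away from prompt 2
mixed = (1 - alpha) * logp_a + alpha * logp_b

w = -alpha                  # same update in the familiar CFG form
cfg = logp_a + w * (logp_a - logp_b)
assert torch.allclose(mixed, cfg, atol = 1e-6)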
