improved next-token logit computations
csinva committed Mar 28, 2024
1 parent b048ec6 commit 84b2690
Showing 2 changed files with 121 additions and 50 deletions.
169 changes: 120 additions & 49 deletions imodelsx/llm.py
@@ -314,10 +314,27 @@ def __call__(
         do_sample=False,
         use_cache=True,
         verbose=False,
+        return_next_token_prob_scores=False,
+        target_token_strs: List[str] = None,
+        return_top_target_token_str: bool = False,
         # batch_size=1,
     ) -> Union[str, List[str]]:
"""Warning: stop is used posthoc but not during generation.
Be careful, caching can take up a lot of memory....
Params
------
return_next_token_prob_scores: bool
If this is true, then the function will return the probability of the next token being each of the target_token_strs
target_token_strs: List[str]
If this is not None and return_next_token_prob_scores is True, then the function will return the probability of the next token being each of the target_token_strs
The output will be a list of dictionaries in this case List[Dict[str, float]]
return_top_target_token_str: bool
If true and above are true, then just return top token of the above
This is a way to constrain the output (but only for 1 token)
This setting caches but the other two (which do not return strings) do not cache
"""
         input_is_str = isinstance(prompt, str)
         with torch.no_grad():
@@ -332,7 +349,10 @@ def __call__(
             if os.path.exists(cache_file):
                 if verbose:
                     print("cached!")
-                return pkl.load(open(cache_file, "rb"))
+                try:
+                    return pkl.load(open(cache_file, "rb"))
+                except Exception:
+                    print('failed to load cache so rerunning...')
             if verbose:
                 print("not cached...")

@@ -342,26 +362,57 @@ def __call__(
                 self._tokenizer.pad_token_id = self._tokenizer.eos_token_id
             inputs = self._tokenizer(
                 prompt, return_tensors="pt",
-                return_attention_mask=True, padding=True,
+                return_attention_mask=True,
+                padding=True,
                 truncation=False,
-            ).to(
-                self._model.device
-            ) # .input_ids.to("cuda")
-            # stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=max_tokens)])
-            # outputs = self._model.generate(input_ids, max_length=max_tokens, stopping_criteria=stopping_criteria)
-            # print('pad_token', self._tokenizer.pad_token)
+            ).to(self._model.device)

-            # torch.manual_seed(0)
-            outputs = self._model.generate(
-                **inputs,
-                max_new_tokens=max_new_tokens,
-                do_sample=do_sample,
-                # pad_token=self._tokenizer.pad_token,
-                pad_token_id=self._tokenizer.pad_token_id,
+            if return_next_token_prob_scores:
+                outputs = self._model.generate(
+                    **inputs,
+                    max_new_tokens=1,
+                    pad_token_id=self._tokenizer.pad_token_id,
+                    output_logits=True,
+                    return_dict_in_generate=True,
+                )
+                next_token_logits = outputs['logits'][0]
+                next_token_probs = next_token_logits.softmax(
+                    axis=-1).detach().cpu().numpy()
+
+                if target_token_strs is not None:
+                    target_token_ids = self._check_target_token_strs(
+                        target_token_strs)
+                    if return_top_target_token_str:
+                        selected_tokens = next_token_probs[:, np.array(
+                            target_token_ids)].squeeze().argmax(axis=-1)
+                        out_strs = [
+                            target_token_strs[selected_tokens[i]]
+                            for i in range(len(selected_tokens))
+                        ]
+                        if use_cache:
+                            pkl.dump(out_strs, open(cache_file, "wb"))
+                        return out_strs
+                    else:
+                        out_dict_list = [
+                            {target_token_strs[i]: next_token_probs[prompt_num, target_token_ids[i]]
+                             for i in range(len(target_token_strs))
+                             }
+                            for prompt_num in range(len(prompt))
+                        ]
+                        return out_dict_list
+                else:
+                    return next_token_probs
+            else:
+                outputs = self._model.generate(
+                    **inputs,
+                    max_new_tokens=max_new_tokens,
+                    do_sample=do_sample,
+                    pad_token_id=self._tokenizer.pad_token_id,
+                )
-                # top_p=0.92,
-                # temperature=0,
-                # top_k=0
-            )
             if input_is_str:
                 out_str = self._tokenizer.decode(
                     outputs[0], skip_special_tokens=True)
@@ -381,27 +432,27 @@ def __call__(
                 pkl.dump(out_strs, open(cache_file, "wb"))
             return out_strs

-    def _get_logit_for_target_token(self, prompt: str, target_token_str: str) -> float:
-        """Get logits target_token_str
-        This is weird when token_output_ids represents multiple tokens
-        It currently will only take the first token
-        """
-        # Get first token id in target_token_str
-        target_token_id = self._tokenizer(target_token_str)["input_ids"][0]
-
-        # get prob of target token
-        inputs = self._tokenizer(
-            prompt,
-            return_tensors="pt",
-            return_attention_mask=True,
-            padding=False,
-            truncation=False,
-        ).to(self._model.device)
-        # shape is (batch_size, seq_len, vocab_size)
-        logits = self._model(**inputs)["logits"].detach().cpu()
-        # shape is (vocab_size,)
-        probs_next_token = softmax(logits[0, -1, :].numpy().flatten())
-        return probs_next_token[target_token_id]
+    def _check_target_token_strs(self, target_token_strs, override_token_with_first_token_id=False):
+        # deal with target_token_strs.... ######################
+        if isinstance(target_token_strs, str):
+            target_token_strs = [target_token_strs]
+
+        target_token_ids = [self._tokenizer(target_token_str)["input_ids"]
+                            for target_token_str in target_token_strs]
+
+        # Check that the target token is in the vocab
+        if override_token_with_first_token_id:
+            # Get first token id in target_token_str
+            target_token_ids = [target_token_id[0]
+                                for target_token_id in target_token_ids]
+        else:
+            for i in range(len(target_token_strs)):
+                if len(target_token_ids[i]) > 1:
+                    raise ValueError(
+                        f"target_token_str {target_token_strs[i]} has multiple tokens: " +
+                        str([self._tokenizer.decode(target_token_id)
+                             for target_token_id in target_token_ids[i]]))
+        return target_token_ids


 if __name__ == "__main__":
@@ -421,16 +472,36 @@ def _get_logit_for_target_token(self, prompt: str, target_token_str: str) -> float:
     # model = transformers.LlamaForCausalLM.from_pretrained("chaoyi-wu/PMC_LLAMA_7B")

     # llm = get_llm("chaoyi-wu/PMC_LLAMA_7B")
-    llm = get_llm("llama_65b")
-    text = llm(
-        """Continue this list
-- red
-- orange
-- yellow
-- green
--""",
-        use_cache=False,
-    )
-    print(text)
-    print("\n\n")
-    print(repr(text))
+    # llm = get_llm("llama_65b")
+    # text = llm(
+    #     """Continue this list
+    # - red
+    # - orange
+    # - yellow
+    # - green
+    # -""",
+    #     use_cache=False,
+    # )
+    # print(text)
+    # print("\n\n")
+    # print(repr(text))
+
+    # GET LOGITS ###################################
+    # llm = get_llm("gpt2")
+    # prompts = ['roses are red, violets are', 'may the force be with']
+    # # prompts = ['may the force be with', 'so may the light be with']
+    # target_token_strs = [' blue', ' you']
+    # ans = llm(prompts, return_next_token_prob_scores=True,
+    #           use_cache=False, target_token_strs=target_token_strs)
+
+    # FORCE WORDSSSSSSSSS ##########
+    llm = get_llm("gpt2")
+    prompts = ['roses are red, violets are',
+               'may the force be with', 'trees are usually']
+    # prompts = ['may the force be with', 'so may the light be with']
+    target_token_strs = [' green', ' you', 'orange']
+    llm._check_target_token_strs(target_token_strs)
+    ans = llm(prompts, use_cache=False,
+              return_next_token_prob_scores=True, target_token_strs=target_token_strs,
+              return_top_target_token_str=True)
+    print('ans', ans)
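
A minimal usage sketch of the new arguments documented in __call__'s docstring above (assuming imodelsx and its transformers backend are installed; get_llm, the prompts, and the target tokens are taken from the __main__ examples, and the printed values are only illustrative):

    from imodelsx.llm import get_llm

    llm = get_llm("gpt2")
    prompts = ['roses are red, violets are', 'may the force be with']
    target_token_strs = [' blue', ' you']  # each must be a single token for the model's tokenizer

    # One {target token: probability} dict per prompt
    prob_dicts = llm(prompts, use_cache=False,
                     return_next_token_prob_scores=True,
                     target_token_strs=target_token_strs)

    # Constrain the output to the most likely target token (one string per prompt)
    top_tokens = llm(prompts, use_cache=False,
                     return_next_token_prob_scores=True,
                     target_token_strs=target_token_strs,
                     return_top_target_token_str=True)

    print(prob_dicts)  # e.g. [{' blue': ..., ' you': ...}, {' blue': ..., ' you': ...}]
    print(top_tokens)  # e.g. [' blue', ' you']
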
2 changes: 1 addition & 1 deletion setup.py
@@ -28,7 +28,7 @@

 setuptools.setup(
     name="imodelsx",
-    version="0.4.0",
+    version="0.4.1",
     author="Chandan Singh, John X. Morris, Armin Askari, Divyanshu Aggarwal, Aliyah Hsu, Yuntian Deng",
     author_email="[email protected]",
     description="Library to explain a dataset in natural language.",
