Commit

minor cleanup
csinva committed Dec 7, 2022
1 parent 13f6142 commit 44adb4c
Showing 3 changed files with 38 additions and 34 deletions.
20 changes: 9 additions & 11 deletions imodelsx/embgam/embed.py
@@ -11,7 +11,7 @@ def generate_ngrams_list(
parsing: str='',
nlp_chunks=None,
):
"""Get list of ngrams from sentence
"""Get list of ngrams from sentence using a tokenizer
Params
------
@@ -131,19 +131,17 @@ def embed_and_sum_function(
sentence = example
# seqs = sentence

if isinstance(sentence, str):
seqs = generate_ngrams_list(
sentence, ngrams=ngrams, tokenizer_ngrams=tokenizer_ngrams,
parsing=parsing, nlp_chunks=nlp_chunks, all_ngrams=all_ngrams,
)
elif isinstance(sentence, list):
raise Exception('batched mode not supported')
# seqs = list(map(generate_ngrams_list, sentence))
assert isinstance(sentence, str), 'sentence must be a string (batched mode not supported)'
seqs = generate_ngrams_list(
sentence, ngrams=ngrams, tokenizer_ngrams=tokenizer_ngrams,
parsing=parsing, nlp_chunks=nlp_chunks, all_ngrams=all_ngrams,
)
# seqs = list(map(generate_ngrams_list, sentence))


# maybe a smarter way to deal with pooling here?
seq_len = len(seqs)
if seq_len == 0:
seqs = ["dummy"]
seqs = ["dummy"] # will multiply embedding by 0 so doesn't matter

if 'bert' in checkpoint.lower(): # has up to two keys, 'last_hidden_state', 'pooler_output'
if not hasattr(tokenizer_embeddings, 'pad_token') or tokenizer_embeddings.pad_token is None:
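
For readers of this diff, a minimal usage sketch of generate_ngrams_list reflecting the behavior after this commit: embed_and_sum_function now simply asserts that its input is a single string rather than branching on lists. The spaCy tokenizer and the example arguments below are illustrative assumptions, not taken from this change.

from spacy.lang.en import English
import imodelsx.embgam.embed

tokenizer_ngrams = English().tokenizer  # assumed tokenizer; the library default may differ

# ngrams are generated one sentence (string) at a time
seqs = imodelsx.embgam.embed.generate_ngrams_list(
    "this movie was great",
    ngrams=2,
    tokenizer_ngrams=tokenizer_ngrams,
    all_ngrams=False,
)
print(seqs)  # e.g. ['this movie', 'movie was', 'was great']
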
50 changes: 28 additions & 22 deletions imodelsx/embgam/embgam.py
@@ -168,18 +168,23 @@ def cache_linear_coefs(self, X: ArrayLike, model=None, tokenizer_embeddings=None
print('\tNothing to update!')
return

# compute embeddings
"""
# Faster version that needs more memory
tokens = tokenizer(ngrams_list, padding=args.padding,
truncation=True, return_tensors="pt")
tokens = tokens.to(device)
embs = self._get_embs(ngrams_list, model, tokenizer_embeddings)
if self.normalize_embs:
embs = self.normalizer.transform(embs)

output = model(**tokens) # this takes a while....
embs = output['pooler_output'].cpu().detach().numpy()
return embs
# save coefs
coef_embs = self.linear.coef_.squeeze().transpose()
linear_coef = embs @ coef_embs
self.coefs_dict_ = {
**coefs_dict_old,
**{ngrams_list[i]: linear_coef[i]
for i in range(len(ngrams_list))}
}
print('coefs_dict_ len', len(self.coefs_dict_))

def _get_embs(self, ngrams_list, model, tokenizer_embeddings):
"""Get embeddings for a list of ngrams (not summed!)
"""
# Slower way to run things but won't run out of mem
embs = []
for i in tqdm(range(len(ngrams_list))):
tokens = tokenizer_embeddings(
@@ -191,18 +196,19 @@ def cache_linear_coefs(self, X: ArrayLike, model=None, tokenizer_embeddings=None
emb = emb.mean(axis=1)
embs.append(emb)
embs = np.array(embs).squeeze()
if self.normalize_embs:
embs = self.normalizer.transform(embs)
return embs

# save coefs
coef_embs = self.linear.coef_.squeeze().transpose()
linear_coef = embs @ coef_embs
self.coefs_dict_ = {
**coefs_dict_old,
**{ngrams_list[i]: linear_coef[i]
for i in range(len(ngrams_list))}
}
print('coefs_dict_ len', len(self.coefs_dict_))
"""
# Faster version that needs more memory
tokens = tokenizer(ngrams_list, padding=args.padding,
truncation=True, return_tensors="pt")
tokens = tokens.to(device)
output = model(**tokens) # this takes a while....
embs = output['pooler_output'].cpu().detach().numpy()
return embs
"""


def _get_ngrams_list(self, X):
all_ngrams = set()
@@ -251,7 +257,7 @@ def _predict_cached(self, X, warn):
n_unseen_ngrams = 0
for x in X:
pred = 0
seqs = imodelsx.embgam.embed.generate_ngrams_list(
seqs = imodelsx.embgam.embed.generate_ngrams_list(
x,
ngrams=self.ngrams,
tokenizer_ngrams=self.tokenizer_ngrams,
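
Because the hunks above interleave removed and added lines, here is a small self-contained sketch of the coefficient-caching step that cache_linear_coefs performs after this refactor. The embedding matrix and weight vector are random stand-ins for the real output of _get_embs and the fitted self.linear.coef_; only the projection and dict construction mirror the code in the diff.

import numpy as np

ngrams_list = ["good movie", "bad plot"]      # hypothetical ngrams
embs = np.random.rand(len(ngrams_list), 768)  # stand-in for _get_embs(...): one embedding per ngram
coef_embs = np.random.rand(768)               # stand-in for self.linear.coef_.squeeze().transpose()

# project each ngram embedding onto the linear model's weights
linear_coef = embs @ coef_embs

# cache one scalar coefficient per ngram, keyed by the ngram string
coefs_dict_ = {ngrams_list[i]: linear_coef[i] for i in range(len(ngrams_list))}
print(coefs_dict_)
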
2 changes: 1 addition & 1 deletion setup.py
@@ -21,7 +21,7 @@

setuptools.setup(
name="imodelsx",
version="0.04",
version="0.05",
author="Chandan Singh, John X. Morris",
author_email="[email protected]",
description="Library to explain a dataset in natural language.",