Merge pull request #6 from csinva/divyanshuaggarwal-da_min_ngram
Min frequency ngrams from @divyanshuaggarwal
csinva authored Apr 16, 2023
2 parents 51754dd + 18a1fbe commit f306381
Showing 8 changed files with 245 additions and 78 deletions.
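In short, this PR (1) adds a min_frequency option that filters rare ngrams out of the generated ngram list and (2) switches embedding extraction to batched inference with a configurable batch_size. Below is a minimal usage sketch; the estimator name (AugGAMClassifier) and the dataset loading are illustrative assumptions not shown in this diff, while the keyword arguments mirror the updated tests/test_pipeline.py.

# Hypothetical usage sketch (not part of the diff): the estimator name and
# dataset loading are assumptions; keyword arguments follow tests/test_pipeline.py.
import datasets
import imodelsx

dset = datasets.load_dataset('rotten_tomatoes')['train']
m = imodelsx.AugGAMClassifier(
    checkpoint='textattack/distilbert-base-uncased-rotten-tomatoes',
    ngrams=2,
    all_ngrams=True,   # also use lower-order ngrams
    min_frequency=1,   # new in this PR: keep only ngrams occurring at least this often
)
m.fit(dset['text'], dset['label'], batch_size=8)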
140 changes: 107 additions & 33 deletions docs/auggam/auggam.html

Large diffs are not rendered by default.

137 changes: 102 additions & 35 deletions docs/auggam/embed.html
@@ -29,9 +29,14 @@ <h1 class="title">Module <code>imodelsx.auggam.embed</code></h1>
<pre><code class="python">from transformers import BertModel, DistilBertModel
from transformers import AutoModelForCausalLM
from os.path import join as oj
from datasets import Dataset
from tqdm import tqdm
import torch
import numpy as np
from torch.utils.data import DataLoader
import imodelsx.util


def get_model(checkpoint):
if 'distilbert' in checkpoint.lower():
model = DistilBertModel.from_pretrained(checkpoint)
@@ -79,7 +84,8 @@ <h1 class="title">Module <code>imodelsx.auggam.embed</code></h1>
checkpoint: str,
dataset_key_text: str = None,
layer: str = 'last_hidden_state',
padding: bool = True,
padding: str = "max_length",
batch_size: int = 8,
parsing: str = '',
nlp_chunks=None,
all_ngrams: bool = False,
@@ -134,27 +140,57 @@ <h1 class="title">Module <code>imodelsx.auggam.embed</code></h1>
tokenizer_embeddings.pad_token = tokenizer_embeddings.eos_token
tokens = tokenizer_embeddings(seqs, padding=padding,
truncation=True, return_tensors="pt")
tokens = tokens.to(model.device)
output = model(**tokens)
if layer == 'pooler_output':
embs = output['pooler_output'].cpu().detach().numpy()
elif layer == 'last_hidden_state_mean' or layer == 'last_hidden_state':
embs = output['last_hidden_state'].cpu().detach().numpy()
embs = embs.mean(axis=1)

embs = []

ds = Dataset.from_dict(tokens).with_format("torch")

for batch in DataLoader(ds, batch_size=batch_size, shuffle=False):
batch = {k: v.to(model.device) for k, v in batch.items()}

with torch.no_grad():
output = model(**batch)
torch.cuda.empty_cache()

if layer == 'pooler_output':
emb = output['pooler_output'].cpu().detach().numpy()
elif layer == 'last_hidden_state_mean' or layer == 'last_hidden_state':
emb = output['last_hidden_state'].cpu().detach().numpy()
emb = emb.mean(axis=1)

embs.append(emb)

embs = np.concatenate(embs)

elif 'gpt' in checkpoint.lower():
tokens = preprocess_gpt_token_batch(seqs, tokenizer_embeddings)
tokens = tokens.to(model.device)
output = model(**tokens)

# tuple of (layer x (batch_size, seq_len, hidden_size))
h = output['hidden_states']
# (batch_size, seq_len, hidden_size)
embs = h[0].cpu().detach().numpy()
embs = embs.mean(axis=1) # (batch_size, hidden_size)

embs = []

ds = Dataset.from_dict(tokens).with_format("torch")

for batch in DataLoader(ds, batch_size=batch_size, shuffle=False):
batch = {k: v.to(model.device) for k, v in batch.items()}

with torch.no_grad():
output = model(**batch)
torch.cuda.empty_cache()

# tuple of (layer x (batch_size, seq_len, hidden_size))
h = output['hidden_states']
# (batch_size, seq_len, hidden_size)
emb = h[0].cpu().detach().numpy()
emb = emb.mean(axis=1) # (batch_size, hidden_size)

embs.append(emb)

embs = np.concatenate(embs)

elif checkpoint.startswith('hkunlp/instructor'):
if instructor_prompt is None:
instructor_prompt = "Represent the short phrase for sentiment classification: "
embs = model.encode([[instructor_prompt, x_i] for x_i in seqs], batch_size=32)
embs = model.encode([[instructor_prompt, x_i]
for x_i in seqs], batch_size=batch_size)

# sum over the embeddings
embs = embs.sum(axis=0).reshape(1, -1)
@@ -172,7 +208,7 @@ <h1 class="title">Module <code>imodelsx.auggam.embed</code></h1>
<h2 class="section-title" id="header-functions">Functions</h2>
<dl>
<dt id="imodelsx.auggam.embed.embed_and_sum_function"><code class="name flex">
<span>def <span class="ident">embed_and_sum_function</span></span>(<span>example, model, ngrams: int, tokenizer_embeddings, tokenizer_ngrams, checkpoint: str, dataset_key_text: str = None, layer: str = 'last_hidden_state', padding: bool = True, parsing: str = '', nlp_chunks=None, all_ngrams: bool = False, fit_with_ngram_decomposition: bool = True, instructor_prompt: str = None)</span>
<span>def <span class="ident">embed_and_sum_function</span></span>(<span>example, model, ngrams: int, tokenizer_embeddings, tokenizer_ngrams, checkpoint: str, dataset_key_text: str = None, layer: str = 'last_hidden_state', padding: str = 'max_length', batch_size: int = 8, parsing: str = '', nlp_chunks=None, all_ngrams: bool = False, fit_with_ngram_decomposition: bool = True, instructor_prompt: str = None)</span>
</code></dt>
<dd>
<div class="desc"><p>Get summed embeddings for a single example</p>
@@ -206,7 +242,8 @@ <h2 id="params">Params</h2>
checkpoint: str,
dataset_key_text: str = None,
layer: str = 'last_hidden_state',
padding: bool = True,
padding: str = "max_length",
batch_size: int = 8,
parsing: str = '',
nlp_chunks=None,
all_ngrams: bool = False,
@@ -261,27 +298,57 @@ <h2 id="params">Params</h2>
tokenizer_embeddings.pad_token = tokenizer_embeddings.eos_token
tokens = tokenizer_embeddings(seqs, padding=padding,
truncation=True, return_tensors="pt")
tokens = tokens.to(model.device)
output = model(**tokens)
if layer == 'pooler_output':
embs = output['pooler_output'].cpu().detach().numpy()
elif layer == 'last_hidden_state_mean' or layer == 'last_hidden_state':
embs = output['last_hidden_state'].cpu().detach().numpy()
embs = embs.mean(axis=1)

embs = []

ds = Dataset.from_dict(tokens).with_format("torch")

for batch in DataLoader(ds, batch_size=batch_size, shuffle=False):
batch = {k: v.to(model.device) for k, v in batch.items()}

with torch.no_grad():
output = model(**batch)
torch.cuda.empty_cache()

if layer == 'pooler_output':
emb = output['pooler_output'].cpu().detach().numpy()
elif layer == 'last_hidden_state_mean' or layer == 'last_hidden_state':
emb = output['last_hidden_state'].cpu().detach().numpy()
emb = emb.mean(axis=1)

embs.append(emb)

embs = np.concatenate(embs)

elif 'gpt' in checkpoint.lower():
tokens = preprocess_gpt_token_batch(seqs, tokenizer_embeddings)
tokens = tokens.to(model.device)
output = model(**tokens)

# tuple of (layer x (batch_size, seq_len, hidden_size))
h = output['hidden_states']
# (batch_size, seq_len, hidden_size)
embs = h[0].cpu().detach().numpy()
embs = embs.mean(axis=1) # (batch_size, hidden_size)

embs = []

ds = Dataset.from_dict(tokens).with_format("torch")

for batch in DataLoader(ds, batch_size=batch_size, shuffle=False):
batch = {k: v.to(model.device) for k, v in batch.items()}

with torch.no_grad():
output = model(**batch)
torch.cuda.empty_cache()

# tuple of (layer x (batch_size, seq_len, hidden_size))
h = output['hidden_states']
# (batch_size, seq_len, hidden_size)
emb = h[0].cpu().detach().numpy()
emb = emb.mean(axis=1) # (batch_size, hidden_size)

embs.append(emb)

embs = np.concatenate(embs)

elif checkpoint.startswith('hkunlp/instructor'):
if instructor_prompt is None:
instructor_prompt = "Represent the short phrase for sentiment classification: "
embs = model.encode([[instructor_prompt, x_i] for x_i in seqs], batch_size=32)
embs = model.encode([[instructor_prompt, x_i]
for x_i in seqs], batch_size=batch_size)

# sum over the embeddings
embs = embs.sum(axis=0).reshape(1, -1)
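The substantive change in embed_and_sum_function above is that tokenized sequences are no longer pushed through the model in a single forward pass: they are wrapped in a datasets.Dataset and iterated with a torch DataLoader under torch.no_grad(), which keeps GPU memory bounded on long inputs. A standalone sketch of that pattern follows; the checkpoint and example sentences are illustrative assumptions, not taken from the diff.

# Standalone sketch of the batched-embedding pattern introduced above.
# The checkpoint and example sentences are illustrative assumptions.
import numpy as np
import torch
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, BertModel

checkpoint = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = BertModel.from_pretrained(checkpoint).eval()

seqs = ['the quick brown fox', 'jumps over', 'the lazy dog']
tokens = tokenizer(seqs, padding='max_length', truncation=True, return_tensors='pt')

# Wrap the tokenized tensors so a DataLoader can slice them into batches.
ds = Dataset.from_dict(tokens).with_format('torch')

embs = []
for batch in DataLoader(ds, batch_size=8, shuffle=False):
    batch = {k: v.to(model.device) for k, v in batch.items()}
    with torch.no_grad():  # embeddings only, no gradients
        output = model(**batch)
    # mean-pool the last hidden state over the sequence dimension
    emb = output['last_hidden_state'].cpu().numpy().mean(axis=1)
    embs.append(emb)

embs = np.concatenate(embs)  # (n_seqs, hidden_size)

As in the code above, the mean pooling runs over every token position (padding included); the attention mask is consumed by the model but not by the pooling step.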
9 changes: 2 additions & 7 deletions docs/index.html
@@ -83,17 +83,12 @@
<td>Black-box model</td>
<td>Finetune a single linear layer<br/>on top of LLM embeddings</td>
</tr>
<tr>
<td style="text-align: left;">(Coming soon!)</td>
<td></td>
<td></td>
<td>We plan to support other interpretable models like <a href="https://arxiv.org/abs/2205.12548">RLPrompt</a>, <a href="https://arxiv.org/abs/2007.04612">CBMs</a>, <a href="https://proceedings.neurips.cc/paper/2021/hash/251bd0442dfcc53b5a761e050f8022b8-Abstract.html">NAMs</a>, and <a href="https://arxiv.org/abs/2004.00221">NBDT</a></td>
</tr>
</tbody>
</table>
<p align="center">
Demo notebooks <a href="https://github.com/csinva/imodelsX/tree/master/demo_notebooks">📖</a>, Doc <a href="https://csinva.io/imodelsX/">🗂️</a>, Reference code implementation 🔗, Research paper 📄
<a href="https://github.com/csinva/imodelsX/tree/master/demo_notebooks">📖</a>Demo notebooks &emsp; <a href="https://csinva.io/imodelsX/">🗂️</a> Doc &emsp; 🔗 Reference code &emsp; 📄 Research paper
</br>
⌛ We plan to support other interpretable algorithms like <a href="https://arxiv.org/abs/2205.12548">RLPrompt</a>, <a href="https://arxiv.org/abs/2007.04612">CBMs</a>, and <a href="https://arxiv.org/abs/2004.00221">NBDT</a>. If you want to contribute an algorithm, feel free to open a PR 😄
</p>
<h1 id="quickstart">Quickstart</h1>
<p><strong>Installation</strong>: <code>pip install <a title="imodelsx" href="#imodelsx">imodelsx</a></code> (or, for more control, clone and install from source)</p>
21 changes: 19 additions & 2 deletions docs/util.html
@@ -32,6 +32,7 @@ <h1 class="title">Module <code>imodelsx.util</code></h1>
from transformers import pipeline
import datasets
import numpy as np
from collections import Counter


def generate_ngrams_list(
@@ -42,6 +43,7 @@ <h1 class="title">Module <code>imodelsx.util</code></h1>
parsing: str = '',
nlp_chunks=None,
pad_starting_ngrams=False,
min_frequency=1,
):
"""Get list of ngrams from sentence using a tokenizer

@@ -55,6 +57,8 @@ <h1 class="title">Module <code>imodelsx.util</code></h1>
if all_ngrams=False, then pad starting ngrams with shorter length ngrams
so that length of ngrams_list is the same as the initial sequence
e.g. for ngrams=3 ["the", "the quick", "the quick brown", "quick brown fox", "brown fox jumps", ...]
min_frequency: int
minimum frequency to be considered for the ngrams_list
"""

seqs = []
@@ -96,6 +100,10 @@ <h1 class="title">Module <code>imodelsx.util</code></h1>
assert all_ngrams is False, "pad_starting_ngrams only works when all_ngrams=False"
seqs_init = [' '.join(unigrams_list[:ngram_length]) for ngram_length in range(1, ngrams)]
seqs = seqs_init + seqs

freqs = Counter(seqs)

seqs = [seq for seq, freq in freqs.items() if freq >= min_frequency]

return seqs

@@ -177,7 +185,7 @@ <h1 class="title">Module <code>imodelsx.util</code></h1>
<h2 class="section-title" id="header-functions">Functions</h2>
<dl>
<dt id="imodelsx.util.generate_ngrams_list"><code class="name flex">
<span>def <span class="ident">generate_ngrams_list</span></span>(<span>sentence: str, ngrams: int, tokenizer_ngrams=None, all_ngrams=False, parsing: str = '', nlp_chunks=None, pad_starting_ngrams=False)</span>
<span>def <span class="ident">generate_ngrams_list</span></span>(<span>sentence: str, ngrams: int, tokenizer_ngrams=None, all_ngrams=False, parsing: str = '', nlp_chunks=None, pad_starting_ngrams=False, min_frequency=1)</span>
</code></dt>
<dd>
<div class="desc"><p>Get list of ngrams from sentence using a tokenizer</p>
@@ -189,7 +197,9 @@ <h2 id="params">Params</h2>
pad_starting_ngrams: bool
if all_ngrams=False, then pad starting ngrams with shorter length ngrams
so that length of ngrams_list is the same as the initial sequence
e.g. for ngrams=3 ["the", "the quick", "the quick brown", "quick brown fox", "brown fox jumps", &hellip;]</p></div>
e.g. for ngrams=3 ["the", "the quick", "the quick brown", "quick brown fox", "brown fox jumps", &hellip;]
min_frequency: int
minimum frequency to be considered for the ngrams_list</p></div>
<details class="source">
<summary>
<span>Expand source code</span>
@@ -202,6 +212,7 @@ <h2 id="params">Params</h2>
parsing: str = '',
nlp_chunks=None,
pad_starting_ngrams=False,
min_frequency=1,
):
"""Get list of ngrams from sentence using a tokenizer

@@ -215,6 +226,8 @@ <h2 id="params">Params</h2>
if all_ngrams=False, then pad starting ngrams with shorter length ngrams
so that length of ngrams_list is the same as the initial sequence
e.g. for ngrams=3 ["the", "the quick", "the quick brown", "quick brown fox", "brown fox jumps", ...]
min_frequency: int
minimum frequency to be considered for the ngrams_list
"""

seqs = []
@@ -256,6 +269,10 @@ <h2 id="params">Params</h2>
assert all_ngrams is False, "pad_starting_ngrams only works when all_ngrams=False"
seqs_init = [' '.join(unigrams_list[:ngram_length]) for ngram_length in range(1, ngrams)]
seqs = seqs_init + seqs

freqs = Counter(seqs)

seqs = [seq for seq, freq in freqs.items() if freq >= min_frequency]

return seqs</code></pre>
</details>
5 changes: 5 additions & 0 deletions imodelsx/auggam/auggam.py
@@ -36,6 +36,7 @@ def __init__(
layer: str = 'last_hidden_state',
ngrams: int = 2,
all_ngrams: bool = False,
min_frequency: int = 1,
tokenizer_ngrams=None,
random_state=None,
normalize_embs=False,
@@ -54,6 +55,8 @@ def __init__(
Order of ngrams to extract. 1 for unigrams, 2 for bigrams, etc.
all_ngrams
Whether to use all order ngrams <= ngrams argument
min_frequency
minimum frequency of ngrams to be kept in the ngrams list.
tokenizer_ngrams
if None, defaults to spacy English tokenizer
random_state
@@ -76,6 +79,7 @@ def __init__(
self.layer = layer
self.random_state = random_state
self.all_ngrams = all_ngrams
self.min_frequency = min_frequency
self.normalize_embs = normalize_embs
self.fit_with_ngram_decomposition = fit_with_ngram_decomposition
self.instructor_prompt = instructor_prompt
@@ -284,6 +288,7 @@ def _get_ngrams_list(self, X):
ngrams=self.ngrams,
tokenizer_ngrams=self.tokenizer_ngrams,
all_ngrams=self.all_ngrams,
min_frequency=self.min_frequency
)
all_ngrams |= set(seqs)
return sorted(list(all_ngrams))
8 changes: 8 additions & 0 deletions imodelsx/util.py
@@ -6,6 +6,7 @@
from transformers import pipeline
import datasets
import numpy as np
from collections import Counter


def generate_ngrams_list(
@@ -16,6 +17,7 @@ def generate_ngrams_list(
parsing: str = '',
nlp_chunks=None,
pad_starting_ngrams=False,
min_frequency=1,
):
"""Get list of ngrams from sentence using a tokenizer
@@ -29,6 +31,8 @@ def generate_ngrams_list(
if all_ngrams=False, then pad starting ngrams with shorter length ngrams
so that length of ngrams_list is the same as the initial sequence
e.g. for ngrams=3 ["the", "the quick", "the quick brown", "quick brown fox", "brown fox jumps", ...]
min_frequency: int
minimum frequency to be considered for the ngrams_list
"""

seqs = []
@@ -70,6 +74,10 @@ def generate_ngrams_list(
assert all_ngrams is False, "pad_starting_ngrams only works when all_ngrams=False"
seqs_init = [' '.join(unigrams_list[:ngram_length]) for ngram_length in range(1, ngrams)]
seqs = seqs_init + seqs

freqs = Counter(seqs)

seqs = [seq for seq, freq in freqs.items() if freq >= min_frequency]

return seqs

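The new block at the end of generate_ngrams_list counts the extracted ngrams with a Counter and keeps only those reaching min_frequency. Since the function runs on one sentence at a time, the threshold counts repetitions within that sentence, and because iteration is over Counter keys the returned list is also de-duplicated. A small illustrative sketch of the added logic, with a made-up ngram list:

# Illustrative sketch of the min_frequency filter added above; the input list is made up.
from collections import Counter

# bigrams extracted from one sentence, e.g. "the cat sat on the cat"
seqs = ['the cat', 'cat sat', 'sat on', 'on the', 'the cat']

min_frequency = 2
freqs = Counter(seqs)
seqs = [seq for seq, freq in freqs.items() if freq >= min_frequency]
print(seqs)  # ['the cat'] -- only ngrams repeated within this sentence survive

With the default min_frequency=1 every ngram passes the threshold, so the observable change in that case is only the de-duplication of repeated ngrams.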
2 changes: 1 addition & 1 deletion setup.py
@@ -24,7 +24,7 @@

setuptools.setup(
name="imodelsx",
version="0.20",
version="0.21",
author="Chandan Singh, John X. Morris, Armin Askari",
author_email="[email protected]",
description="Library to explain a dataset in natural language.",
1 change: 1 addition & 0 deletions tests/test_pipeline.py
@@ -15,6 +15,7 @@
checkpoint='textattack/distilbert-base-uncased-rotten-tomatoes',
ngrams=2,
all_ngrams=True, # also use lower-order ngrams
min_frequency=1
)
m.fit(dset['text'], dset['label'], batch_size=8)
