Skip to content

Commit

Permalink
refactor auggam -> auglinear
Browse files Browse the repository at this point in the history
  • Loading branch information
csinva committed Dec 13, 2023
1 parent 78b43b1 commit ce229d6
Show file tree
Hide file tree
Showing 9 changed files with 109 additions and 142 deletions.
193 changes: 78 additions & 115 deletions demo_notebooks/aug_imodels.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion imodelsx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
.. include:: ../readme.md
"""

from .auggam.auggam import AugGAMClassifier, AugGAMRegressor
from .auglinear.auglinear import AugLinearClassifier, AugLinearRegressor
from .augtree.augtree import AugTreeClassifier, AugTreeRegressor
from .linear_finetune import LinearFinetuneClassifier, LinearFinetuneRegressor
from .linear_ngram import LinearNgramClassifier, LinearNgramRegressor
Expand Down
File renamed without changes.
31 changes: 16 additions & 15 deletions imodelsx/auggam/auggam.py → imodelsx/auglinear/auglinear.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
"""
Simple scikit-learn interface for Emb-GAM.
Simple scikit-learn interface for Aug-Linear.
Aug-GAM: an Interpretable and Efficient Predictor using Pre-trained Language Models
Chandan Singh & Jianfeng Gao
Augmenting Interpretable Models with LLMs during Training
Chandan Singh, Armin Askari, Rich Caruana, Jianfeng Gao
https://arxiv.org/abs/2209.11799
"""
from numpy.typing import ArrayLike
Expand All @@ -15,7 +14,7 @@
from sklearn.utils.validation import check_is_fitted
from sklearn.preprocessing import StandardScaler
import transformers
import imodelsx.auggam.embed
import imodelsx.auglinear.embed
from tqdm import tqdm
import os
import os.path
Expand All @@ -29,7 +28,7 @@
device = "cuda" if torch.cuda.is_available() else "cpu"


class AugGAM(BaseEstimator):
class AugLinear(BaseEstimator):
def __init__(
self,
checkpoint: str = "bert-base-uncased",
Expand All @@ -41,11 +40,10 @@ def __init__(
random_state=None,
normalize_embs=False,
cache_embs_dir: str = None,
cache_coefs_dir: str = None,
fit_with_ngram_decomposition=True,
instructor_prompt=None,
):
"""AugGAM-GAM Class - use either AugGAMClassifier or AugGAMRegressor rather than initializing this class directly.
"""AugLinear Class - use either AugLinearClassifier or AugLinearRegressor rather than initializing this class directly.
Parameters
----------
Expand All @@ -68,7 +66,7 @@ def __init__(
cache_embs_dir: str = None,
if not None, directory to save embeddings into
fit_with_ngram_decomposition
whether to fit to emb-gam style (using sum of embeddings of each ngram)
whether to fit to aug-linear style (using sum of embeddings of each ngram)
if False, fits a typical model and uses ngram decomposition only for prediction / testing
Usually, setting this to False will considerably impede performance
instructor_prompt
Expand Down Expand Up @@ -166,7 +164,7 @@ def fit(
def _get_embs_summed(self, X, model, tokenizer_embeddings, batch_size):
embs = []
for x in tqdm(X):
emb = imodelsx.auggam.embed.embed_and_sum_function(
emb = imodelsx.auglinear.embed.embed_and_sum_function(
x,
model=model,
ngrams=self.ngrams,
Expand Down Expand Up @@ -276,6 +274,7 @@ def normalize_embs(embs, renormalize_embs):
linear_coef = embs @ coef_embs

# save coefs
linear_coef = linear_coef.squeeze()
self.coefs_dict_ = {
**coefs_dict_old,
**{ngrams_list[i]: linear_coef[i] for i in range(len(ngrams_list))},
Expand All @@ -302,7 +301,7 @@ def _get_embs(self, ngrams_list, model, tokenizer_embeddings, batch_size):
embs = np.vstack(embs).squeeze()

else:
embs = imodelsx.auggam.embed.embed_and_sum_function(
embs = imodelsx.auglinear.embed.embed_and_sum_function(
ngrams_list,
model=model,
ngrams=None,
Expand Down Expand Up @@ -347,12 +346,14 @@ def predict(self, X, warn=True):
"""For regression returns continuous output.
For classification, returns discrete output.
"""

check_is_fitted(self)
preds = self._predict_cached(X, warn=warn)
if isinstance(self, RegressorMixin):
return preds
elif isinstance(self, ClassifierMixin):
if preds.ndim > 1: # multiclass classification
# multiclass classification
if preds.ndim > 1:
return np.argmax(preds, axis=1)
else:
return (preds + self.linear.intercept_ > 0).astype(int)
Expand Down Expand Up @@ -398,14 +399,14 @@ def _predict_cached(self, X, warn):
For better performance, call cache_linear_coefs on the test dataset \
before calling predict."
)
return np.array(preds)
return np.array(preds).squeeze()


class AugGAMRegressor(AugGAM, RegressorMixin):
class AugLinearRegressor(AugLinear, RegressorMixin):
...


class AugGAMClassifier(AugGAM, ClassifierMixin):
class AugLinearClassifier(AugLinear, ClassifierMixin):
...


Expand Down
File renamed without changes.
6 changes: 3 additions & 3 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
| AutoPrompt | ㅤㅤ[🗂️](), [🔗](https://github.com/ucinlp/autoprompt), [📄](https://arxiv.org/abs/2010.15980) | Explanation<br/>+ Steering | Find a natural-language prompt<br/>using input-gradients (⌛ In progress)|
| D3 | [🗂️](http://csinva.io/imodelsX/d3/d3.html#imodelsx.d3.d3.explain_dataset_d3), [🔗](https://github.com/ruiqi-zhong/DescribeDistributionalDifferences), [📄](https://arxiv.org/abs/2201.12323), [📖](https://github.com/csinva/imodelsX/blob/master/demo_notebooks/d3.ipynb) | Explanation | Explain the difference between two distributions |
| SASC | ㅤㅤ[🗂️](https://csinva.io/imodelsX/sasc/api.html), [🔗](https://github.com/microsoft/automated-explanations), [📄](https://arxiv.org/abs/2305.09863) | Explanation | Explain a black-box text module<br/>using an LLM (*Official*) |
| Aug-GAM | [🗂️](https://csinva.io/imodelsX/auggam/auggam.html), [🔗](https://github.com/microsoft/aug-models), [📄](https://www.nature.com/articles/s41467-023-43713-1), [📖](https://github.com/csinva/imodelsX/blob/master/demo_notebooks/aug_imodels.ipynb) | Linear model | Fit better linear model using an LLM<br/>to extract embeddings (*Official*) |
| Aug-Linear | [🗂️](https://csinva.io/imodelsX/auglinear/auglinear.html), [🔗](https://github.com/microsoft/aug-models), [📄](https://www.nature.com/articles/s41467-023-43713-1), [📖](https://github.com/csinva/imodelsX/blob/master/demo_notebooks/aug_imodels.ipynb) | Linear model | Fit better linear model using an LLM<br/>to extract embeddings (*Official*) |
| Aug-Tree | [🗂️](https://csinva.io/imodelsX/augtree/augtree.html), [🔗](https://github.com/microsoft/aug-models), [📄](https://www.nature.com/articles/s41467-023-43713-1), [📖](https://github.com/csinva/imodelsX/blob/master/demo_notebooks/aug_imodels.ipynb) | Decision tree | Fit better decision tree using an LLM<br/>to expand features (*Official*) |

<p align="center">
Expand Down Expand Up @@ -167,7 +167,7 @@ explanation_dict = explain_module_sasc(
Use these just a like a scikit-learn model. During training, they fit better features via LLMs, but at test-time they are extremely fast and completely transparent.

```python
from imodelsx import AugGAMClassifier, AugTreeClassifier, AugGAMRegressor, AugTreeRegressor
from imodelsx import AugLinearClassifier, AugTreeClassifier, AugLinearRegressor, AugTreeRegressor
import datasets
import numpy as np

Expand All @@ -178,7 +178,7 @@ dset_val = datasets.load_dataset('rotten_tomatoes')['validation']
dset_val = dset_val.select(np.random.choice(len(dset_val), size=300, replace=False))

# fit model
m = AugGAMClassifier(
m = AugLinearClassifier(
checkpoint='textattack/distilbert-base-uncased-rotten-tomatoes',
ngrams=2, # use bigrams
)
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@

setuptools.setup(
name="imodelsx",
version="0.4.2",
author="Chandan Singh, John X. Morris, Armin Askari, Divyanshu Aggarwal, Aliyah Hsu",
version="0.4.0",
author="Chandan Singh, John X. Morris, Armin Askari, Divyanshu Aggarwal, Aliyah Hsu, Yuntian Deng",
author_email="[email protected]",
description="Library to explain a dataset in natural language.",
long_description=long_description,
Expand Down
5 changes: 3 additions & 2 deletions tests/test_auggam.py → tests/test_auglinear.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from imodelsx import explain_dataset_iprompt, get_add_two_numbers_dataset, AugGAMClassifier
from imodelsx import explain_dataset_iprompt, get_add_two_numbers_dataset, AugLinearClassifier


def test_auggam():
m = AugGAMClassifier()
m = AugLinearClassifier()
10 changes: 6 additions & 4 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from imodelsx import AugGAMClassifier
from imodelsx import AugLinearClassifier
import datasets
import numpy as np

Expand All @@ -11,7 +11,7 @@
len(dset_val), size=10, replace=False))

# fit model
m = AugGAMClassifier(
m = AugLinearClassifier(
checkpoint='textattack/distilbert-base-uncased-rotten-tomatoes',
ngrams=2,
all_ngrams=True, # also use lower-order ngrams
Expand All @@ -27,8 +27,10 @@
# check results when varying batch size
m.fit(dset['text'], dset['label'], batch_size=16)
preds_check = m.predict(dset_val['text'])
assert np.allclose(preds, preds_check), 'predictions should be same when varying batch size'
assert np.allclose(np.array(list(m.coefs_dict_.values())), coefs_orig), 'coefs should be same when varying batch size'
assert np.allclose(
preds, preds_check), 'predictions should be same when varying batch size'
assert np.allclose(np.array(list(m.coefs_dict_.values())),
coefs_orig), 'coefs should be same when varying batch size'

# interpret
print('Total ngram coefficients: ', len(m.coefs_dict_))
Expand Down

0 comments on commit ce229d6

Please sign in to comment.