Merge pull request #34 from BiomedSciAI/torch_encoder
Torch encoder
yoavkt authored Aug 18, 2024
2 parents d291ff7 + 3188074 commit 8812ddf
Showing 4 changed files with 60 additions and 3 deletions.
2 changes: 2 additions & 0 deletions gene_benchmark/deserialization.py
@@ -10,6 +10,7 @@
     NCBIDescriptor,
 )
 from gene_benchmark.encoder import (
+    BERTEncoder,
     MultiEntityEncoder,
     PreComputedEncoder,
     SentenceTransformerEncoder,
@@ -83,4 +84,5 @@ def get_gene_disease_multi_encoder(
     "BasePairDescriptor": BasePairDescriptor,
     "Multilayer_Perceptron_classifier": MLPClassifier,
     "Multilayer_Perceptron_regressor": MLPRegressor,
+    "BERTEncoder": BERTEncoder,
 }
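
For orientation, a minimal sketch of how a name-to-class registry like this is typically consumed when building objects from a config; the registry's real variable name and construction helper live outside this hunk, so class_mapping and build_from_config below are hypothetical stand-ins:

from gene_benchmark.encoder import BERTEncoder

# Hypothetical names -- the actual registry and helper in deserialization.py
# are outside the lines shown in this hunk.
class_mapping = {
    "BERTEncoder": BERTEncoder,
}

def build_from_config(config: dict):
    # Resolve the class named in the config and instantiate it with its kwargs.
    cls = class_mapping[config["class_name"]]
    return cls(**config.get("args", {}))

With this entry registered, a benchmark config can refer to "BERTEncoder" by name and have the new encoder constructed without further code changes.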
47 changes: 44 additions & 3 deletions gene_benchmark/encoder.py
@@ -2,7 +2,10 @@

 import numpy as np
 import pandas as pd
+import torch
 from sentence_transformers import SentenceTransformer
+from transformers import AutoModel, AutoTokenizer
+from transformers.models.bert.configuration_bert import BertConfig

 from .descriptor import add_prefix_to_dict

@@ -172,7 +175,7 @@ def __init__(self, embedding_model_name) -> None:
     def encode(
         self, entities, allow_missing=True, randomize_missing=True, random_len=None
     ):
-        # A None may couse an issue in sentence_transformers/models/Transformer.py#L121
+        # A None may cause an issue in sentence_transformers/models/Transformer.py#L121
         unique_entities = list(filter(None, self._get_unique_entities(entities)))
         if len(unique_entities) > 0:
             unique_encodings = self._get_encoding(
@@ -368,7 +371,7 @@ class SentenceTransformerEncoder(SingleEncoder):
     def __init__(
         self,
         encoder_model_name: str = None,
-        encoder_model: pd.DataFrame = None,
+        encoder_model: SentenceTransformer = None,
         show_progress_bar: bool = True,
         batch_size: int = 32,
     ):
@@ -399,7 +402,7 @@ def _get_encoding(self, entities, **kwargs):
         """
         assert (
             None not in entities
-        ), "A downstram bug will crash on encoding None sometimes, so there should never be a None here."
+        ), "A downstream bug will crash on encoding None sometimes, so there should never be a None here."

         encodings = self.encoder.encode(
             entities,
@@ -413,3 +416,41 @@ def summary(self):
         if self.num_of_missing:
             summary_dict["num_of_missing"] = self.num_of_missing
         return summary_dict
+
+
+class BERTEncoder(SingleEncoder):
+    """Encode a list of descriptions into numeric vectors using transformers BERT encoders."""
+
+    def __init__(
+        self,
+        encoder_model_name: str = None,
+        tokenizer_name: str = None,
+        trust_remote_code: bool = False,
+    ):
+        config = BertConfig.from_pretrained(encoder_model_name)
+        self.encoder = AutoModel.from_pretrained(
+            encoder_model_name, trust_remote_code=trust_remote_code, config=config
+        )
+
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            tokenizer_name, trust_remote_code=trust_remote_code
+        )
+        self.encoder_model_name = encoder_model_name
+        self.tokenizer_name = tokenizer_name
+        self.trust_remote_code = trust_remote_code
+        super().__init__(encoder_model_name)
+
+    def _get_encoding(self, entities, **kwargs):
+        vec_list = []
+        for ent in entities:
+            inputs = self.tokenizer(ent, return_tensors="pt")["input_ids"]
+            hidden_states = self.encoder(inputs)[0]
+            vec_list.append(torch.mean(hidden_states[0], dim=0).detach())
+        return np.array(vec_list)
+
+    def summary(self):
+        summary_dict = super().summary()
+        if self.num_of_missing:
+            summary_dict["num_of_missing"] = self.num_of_missing
+        summary_dict["tokenizer_name"] = self.tokenizer_name
+        return summary_dict
13 changes: 13 additions & 0 deletions gene_benchmark/tests/test_encoder.py
@@ -8,6 +8,7 @@
     NCBIDescriptor,
 )
 from gene_benchmark.encoder import (
+    BERTEncoder,
     MultiEntityEncoder,
     PreComputedEncoder,
     SentenceTransformerEncoder,
@@ -413,3 +414,15 @@ def test_multi_entity_mis_col_error(self):
         ml_enc = MultiEntityEncoder(enc_dict)
         with pytest.raises(Exception, match="columns which are not in the encoding"):
             ml_enc.encode(to_encode)
+
+    @unittest.skip(
+        "The following fails when the GPU is activated because of an issue with Triton flash attention"
+    )
+    def test_TransformerEncoder(self):
+        model = "zhihan1996/DNABERT-2-117M"
+        encoder = BERTEncoder(model, model, trust_remote_code=True)
+        gene1 = "ACGTAGCATCGGATCTATCTATCGACACTTGGTTATCGATCTACGAGCATCTCGTTAGC"
+        gene2 = "GATTACA"
+        encoded = encoder.encode(pd.Series([gene1, gene2]))
+        assert encoded.shape[0] == 2
+        assert encoded.sum()[0] > 0.00001
1 change: 1 addition & 0 deletions pyproject.toml
@@ -25,6 +25,7 @@ dependencies = [
     "sentence_transformers",
     "scikit-learn",
     "click",
+    "einops"
 ]
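
Note: the new einops dependency is presumably required by the DNABERT-2 remote code exercised in the test above; models loaded with trust_remote_code=True can pull in extra runtime packages like this.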

[project.optional-dependencies]
