Merge pull request #88 from kkovary/87-pretrained-hf-transformer-with-gpu-device

PretrainedHFTransformer can use GPUs
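
For context, a minimal usage sketch of what this change enables. The model kind, pooling choice, and device string are illustrative assumptions, not taken from this commit:

    import torch
    from molfeat.trans.pretrained.hf_transformers import PretrainedHFTransformer

    # Illustrative values; any supported HF model kind works the same way
    device = "cuda" if torch.cuda.is_available() else "cpu"
    featurizer = PretrainedHFTransformer(kind="ChemBERTa-77M-MLM", pooling="mean", device=device)
    embeddings = featurizer(["CCO", "c1ccccc1"])  # one embedding vector per input SMILES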
maclandrol authored Nov 17, 2023
2 parents f8579b3 + 2841b42 commit c4eb5df
Showing 1 changed file with 13 additions and 7 deletions.
molfeat/trans/pretrained/hf_transformers.py (20 changes: 13 additions & 7 deletions)
@@ -64,11 +64,12 @@ def save(cls, model: HFExperiment, path: str, clean_up: bool = False):
return path

@classmethod
- def load(cls, path: str, model_class=None):
+ def load(cls, path: str, model_class=None, device: str = "cpu"):
"""Load a model from the given path
Args:
path: Path to the model to load
model_class: optional model class to provide if the model should be loaded with a specific class
+ device: the device to load the model on ("cpu" or "cuda")
"""
if not dm.fs.is_local_path(path):
local_path = tempfile.mkdtemp()
@@ -85,7 +86,7 @@ def load(cls, path: str, model_class=None):
)
else:
model_class = AutoModel
- model = model_class.from_pretrained(local_path)
+ model = model_class.from_pretrained(local_path).to(device)
tokenizer = AutoTokenizer.from_pretrained(local_path)
return cls(model, tokenizer)
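
Note: the core of this hunk is the .to(device) applied right after from_pretrained. A standalone sketch of that pattern, with a placeholder checkpoint name rather than anything from this repository:

    import torch
    from transformers import AutoModel, AutoTokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModel.from_pretrained("bert-base-uncased").to(device)  # placeholder checkpoint
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    print(next(model.parameters()).device)  # cuda:0 when a GPU is available, otherwise cpu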

@@ -209,8 +210,8 @@ def load(self):
name=self.name, download_path=self.cache_path, store=self.store
)
model_path = dm.fs.join(download_output_dir, self.store.MODEL_PATH_NAME)
- model = HFExperiment.load(model_path)
- return model
+ self._model = HFExperiment.load(model_path)
+ return self._model


class PretrainedHFTransformer(PretrainedMolTransformer):
@@ -291,6 +292,7 @@ def __init__(
self.preload = preload
self.pooling = pooling
self.prefer_encoder = prefer_encoder
+ self.device = torch.device(device)
self._pooling_obj = None
if isinstance(kind, HFModel):
self.kind = kind.name
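
Note: torch.device(device) normalizes the user-supplied specifier, so plain strings such as "cpu", "cuda", or "cuda:0" and existing torch.device objects are all accepted. A quick sketch:

    import torch

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device.type)                     # "cuda" or "cpu"
    print(torch.device(device) == device)  # True: passing a torch.device through is a no-op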
@@ -322,6 +324,7 @@ def _update_params(self):
def _preload(self):
"""Perform preloading of the model from the store"""
super()._preload()
+ self.featurizer.model.to(self.device)
self.featurizer.max_length = self.max_length

# we can be confident that the model has been loaded here
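
Note: nn.Module.to() moves parameters in place and returns the module itself, which is why the call above does not need its return value reassigned. A small sketch:

    import torch
    from torch import nn

    layer = nn.Linear(8, 4)
    returned = layer.to(torch.device("cpu"))  # parameters are moved in place; the same module is returned
    print(returned is layer)                  # True
    print(next(layer.parameters()).device)    # cpu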
@@ -408,9 +411,12 @@ def _embed(self, inputs, **kwargs):
"""
self._preload()

+ # Move inputs to the correct device
+ inputs = {key: value.to(self.device) for key, value in inputs.items()}

attention_mask = inputs.get("attention_mask", None)
if attention_mask is not None and self.ignore_padding:
- attention_mask = attention_mask.unsqueeze(-1) # B, S, 1
+ attention_mask = attention_mask.unsqueeze(-1).to(self.device) # B, S, 1
else:
attention_mask = None
with torch.no_grad():
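
Note: the dict comprehension added in this hunk is the standard pattern for moving a tokenized batch onto the model's device. A self-contained sketch with a placeholder tokenizer checkpoint:

    import torch
    from transformers import AutoTokenizer

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder checkpoint
    batch = tokenizer(["CCO", "c1ccccc1"], padding=True, return_tensors="pt")
    batch = {key: value.to(device) for key, value in batch.items()}  # same pattern as in _embed
    print({key: value.device for key, value in batch.items()})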
@@ -424,7 +430,7 @@
hidden_state = out_dict["hidden_states"]
emb_layers = []
for layer in self.concat_layers:
- emb = hidden_state[layer].detach().cpu() # B, S, D
+ emb = hidden_state[layer].detach() # B, S, D
emb = self._pooling_obj(
emb,
inputs["input_ids"],
@@ -433,7 +439,7 @@
)
emb_layers.append(emb)
emb = torch.cat(emb_layers, dim=1)
- return emb.numpy()
+ return emb.cpu().numpy() # Move the final tensor to CPU before converting to numpy array

def set_max_length(self, max_length: int):
"""Set the maximum length for this featurizer"""
