Skip to content

Commit

Permalink
Corrections
Browse files Browse the repository at this point in the history
  • Loading branch information
metric-space committed Oct 3, 2023
1 parent 0de926c commit 101bcf4
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions dalm/models/retriever_only_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def __init__(
def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
if self.is_autoregressive:
# we take the last hidden state of the model
token_embeddings = self.model.sample(
input_ids, attention_mask, output_hidden_states=True, return_dict_in_generate=True
token_embeddings = self.model(
input_ids, attention_mask=attention_mask, output_hidden_states=True, return_dict=True
).hidden_states[-1]
else:
# First element of model_output contains all token embeddings
Expand Down

0 comments on commit 101bcf4

Please sign in to comment.