Skip to content

Commit

Permalink
save
Browse files (browse the repository at this point in the history)
  • Loading branch information
nsthorat committed Feb 14, 2024
1 parent fe03f3f commit 1a4c75a
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 3 deletions.
2 changes: 1 addition & 1 deletion lilac/data/cluster_titling.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential

from ..batch_utils import group_by_sorted_key_iter
from ..env import env
from ..schema import (
Item,
)
Expand All @@ -24,7 +25,6 @@
)
from ..tasks import TaskInfo
from ..utils import chunks, log
from ..env import env

_TOP_K_CENTRAL_DOCS = 7
_TOP_K_CENTRAL_TITLES = 20
Expand Down
3 changes: 2 additions & 1 deletion lilac/embeddings/bge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import gc
from typing import TYPE_CHECKING, ClassVar, Optional

import numpy as np
from typing_extensions import override

from ..utils import log
Expand Down Expand Up @@ -67,7 +68,7 @@ def setup(self) -> None:
def compute(self, docs: list[str]) -> list[Optional[Item]]:
"""Call the embedding function."""

def _encode(doc: list[str]):
def _encode(doc: list[str]) -> list[np.ndarray]:
# Extract the dense vectors from the model.
return self._model.encode(doc)['dense_vecs']

Expand Down
3 changes: 2 additions & 1 deletion lilac/embeddings/nomic_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import gc
from typing import TYPE_CHECKING, ClassVar, Optional

import numpy as np
from typing_extensions import override

if TYPE_CHECKING:
Expand Down Expand Up @@ -65,7 +66,7 @@ def compute(self, docs: list[str]) -> list[Optional[Item]]:
'Please install it with `pip install "sentence_transformers".'
)

def _encode(doc: list[str]):
def _encode(doc: list[str]) -> list[np.ndarray]:
embeddings = self._model.encode(doc, convert_to_tensor=True)
# Extract the dense vectors from the model.
embeddings = F.layer_norm(embeddings, normalized_shape=(embeddings.shape[1],))
Expand Down
5 changes: 5 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,8 @@ follow_imports = skip
[mypy-hdbscan.*]
ignore_missing_imports = True
follow_imports = skip


[mypy-FlagEmbedding.*]
ignore_missing_imports = True
follow_imports = skip

0 comments on commit 1a4c75a

Please sign in to comment.