Skip to content

Commit 5b37dd2

Browse files
authored
fix: add cross_encoder_kwargs parameter for advanced configuration (run-llama#19148)
1 parent 48b544b commit 5b37dd2

File tree

2 files changed

+28
-6
lines changed
  • llama-index-integrations/postprocessor/llama-index-postprocessor-sbert-rerank

2 files changed

+28
-6
lines changed

llama-index-integrations/postprocessor/llama-index-postprocessor-sbert-rerank/llama_index/postprocessor/sbert_rerank/base.py

Lines changed: 27 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,11 @@ class SentenceTransformerRerank(BaseNodePostprocessor):
2828
default=False,
2929
description="Whether to keep the retrieval score in metadata.",
3030
)
31+
cross_encoder_kwargs: dict = Field(
32+
default_factory=dict,
33+
description="Additional keyword arguments for CrossEncoder initialization. "
34+
"device and model should not be included here.",
35+
)
3136
_model: Any = PrivateAttr()
3237

3338
def __init__(
@@ -37,6 +42,7 @@ def __init__(
3742
device: Optional[str] = None,
3843
keep_retrieval_score: Optional[bool] = False,
3944
cache_dir: Optional[Union[str, Path]] = None,
45+
cross_encoder_kwargs: Optional[dict] = None,
4046
):
4147
try:
4248
from sentence_transformers import CrossEncoder
@@ -51,13 +57,29 @@ def __init__(
5157
model=model,
5258
device=device,
5359
keep_retrieval_score=keep_retrieval_score,
60+
cross_encoder_kwargs=cross_encoder_kwargs or {},
5461
)
55-
device = infer_torch_device() if device is None else device
62+
63+
init_kwargs = self.cross_encoder_kwargs.copy()
64+
if "device" in init_kwargs or "model" in init_kwargs:
65+
raise ValueError(
66+
"'device' and 'model' should not be specified in 'cross_encoder_kwargs'. "
67+
"Use the top-level 'device' and 'model' parameters instead."
68+
)
69+
70+
# Set default max_length if not provided by the user in kwargs.
71+
if "max_length" not in init_kwargs:
72+
init_kwargs["max_length"] = DEFAULT_SENTENCE_TRANSFORMER_MAX_LENGTH
73+
74+
# Explicit arguments from the constructor take precedence over kwargs
75+
resolved_device = infer_torch_device() if device is None else device
76+
init_kwargs["device"] = resolved_device
77+
if cache_dir:
78+
init_kwargs["cache_dir"] = cache_dir
79+
5680
self._model = CrossEncoder(
57-
model,
58-
max_length=DEFAULT_SENTENCE_TRANSFORMER_MAX_LENGTH,
59-
device=device,
60-
cache_dir=cache_dir,
81+
model_name=model,
82+
**init_kwargs,
6183
)
6284

6385
@classmethod

llama-index-integrations/postprocessor/llama-index-postprocessor-sbert-rerank/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@ dev = [
2626

2727
[project]
2828
name = "llama-index-postprocessor-sbert-rerank"
29-
version = "0.3.1"
29+
version = "0.3.2"
3030
description = "llama-index postprocessor sbert rerank integration"
3131
authors = [{name = "Your Name", email = "you@example.com"}]
3232
requires-python = ">=3.9,<4.0"

0 commit comments

Comments
 (0)