feat!: Rename model_name_or_path to model in ExtractiveReader (#6736)

* rename model parameter and internal model attribute in ExtractiveReader

* fix tests for ExtractiveReader

* fix e2e

* reno

* another fix

* review feedback

* Update releasenotes/notes/rename-model-param-reader-b8cbb0d638e3b8c2.yaml
ZanSara authored Jan 15, 2024
1 parent b236ea4 commit 96c0b59
Showing 5 changed files with 14 additions and 11 deletions.
2 changes: 1 addition & 1 deletion e2e/pipelines/test_eval_extractive_qa_pipeline.py
@@ -13,7 +13,7 @@ def test_extractive_qa_pipeline(tmp_path):
# Create the pipeline
qa_pipeline = Pipeline()
qa_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore()), name="retriever")
- qa_pipeline.add_component(instance=ExtractiveReader(model_name_or_path="deepset/tinyroberta-squad2"), name="reader")
+ qa_pipeline.add_component(instance=ExtractiveReader(model="deepset/tinyroberta-squad2"), name="reader")
qa_pipeline.connect("retriever", "reader")

# Populate the document store
2 changes: 1 addition & 1 deletion e2e/pipelines/test_extractive_qa_pipeline.py
@@ -10,7 +10,7 @@ def test_extractive_qa_pipeline(tmp_path):
# Create the pipeline
qa_pipeline = Pipeline()
qa_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore()), name="retriever")
- qa_pipeline.add_component(instance=ExtractiveReader(model_name_or_path="deepset/tinyroberta-squad2"), name="reader")
+ qa_pipeline.add_component(instance=ExtractiveReader(model="deepset/tinyroberta-squad2"), name="reader")
qa_pipeline.connect("retriever", "reader")

# Draw the pipeline
10 changes: 5 additions & 5 deletions haystack/components/readers/extractive.py
@@ -38,7 +38,7 @@ class ExtractiveReader:

def __init__(
self,
- model_name_or_path: Union[Path, str] = "deepset/roberta-base-squad2-distilled",
+ model: Union[Path, str] = "deepset/roberta-base-squad2-distilled",
device: Optional[str] = None,
token: Union[bool, str, None] = None,
top_k: int = 20,
@@ -54,7 +54,7 @@ def __init__(
) -> None:
"""
Creates an ExtractiveReader
- :param model_name_or_path: A Hugging Face transformers question answering model.
+ :param model: A Hugging Face transformers question answering model.
Can either be a path to a folder containing the model files or an identifier for the Hugging Face hub.
Default: `'deepset/roberta-base-squad2-distilled'`
:param device: Pytorch device string. Uses GPU by default, if available.
@@ -83,11 +83,11 @@ def __init__(
both of these answers could be kept if this variable is set to 0.24 or lower.
If None is provided then all answers are kept.
:param model_kwargs: Additional keyword arguments passed to `AutoModelForQuestionAnswering.from_pretrained`
- when loading the model specified in `model_name_or_path`. For details on what kwargs you can pass,
+ when loading the model specified in `model`. For details on what kwargs you can pass,
see the model's documentation.
"""
torch_and_transformers_import.check()
- self.model_name_or_path = str(model_name_or_path)
+ self.model_name_or_path = str(model)
self.model = None
self.device = device
self.token = token
@@ -114,7 +114,7 @@ def to_dict(self) -> Dict[str, Any]:
"""
serialization_dict = default_to_dict(
self,
- model_name_or_path=self.model_name_or_path,
+ model=self.model_name_or_path,
device=self.device,
token=self.token if not isinstance(self.token, str) else None,
max_seq_length=self.max_seq_length,
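For context, a minimal sketch of how the renamed keyword is used at construction and serialization time, based on the diff above and the tests below (this assumes torch and transformers are installed; the import path mirrors the `type` string used in the tests and is not part of this diff):

from haystack.components.readers.extractive import ExtractiveReader

# The keyword argument is now `model`; internally the value is still
# stored on `self.model_name_or_path`, as shown in the __init__ diff.
reader = ExtractiveReader(model="deepset/roberta-base-squad2-distilled")

# to_dict() serializes the value under the new `model` key, so
# from_dict() round-trips with the renamed parameter.
config = reader.to_dict()
assert config["init_parameters"]["model"] == "deepset/roberta-base-squad2-distilled"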
3 changes: 3 additions & 0 deletions releasenotes/notes/rename-model-param-reader-b8cbb0d638e3b8c2.yaml
@@ -0,0 +1,3 @@
+ ---
+ upgrade:
+   - Rename parameter `model_name_or_path` to `model` in `ExtractiveReader`.
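For users upgrading, a minimal before/after sketch of the rename described in the release note, reusing the model identifier from the e2e tests (the old keyword is shown commented out only for comparison; the diff does not show any deprecation shim for it):

from haystack.components.readers.extractive import ExtractiveReader

# Before this change:
# reader = ExtractiveReader(model_name_or_path="deepset/tinyroberta-squad2")

# After this change:
reader = ExtractiveReader(model="deepset/tinyroberta-squad2")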
8 changes: 4 additions & 4 deletions test/components/readers/test_extractive.py
@@ -72,7 +72,7 @@ def forward(self, input_ids, attention_mask, *args, **kwargs):

with patch("haystack.components.readers.extractive.AutoModelForQuestionAnswering.from_pretrained") as model:
model.return_value = MockModel()
- reader = ExtractiveReader(model_name_or_path="mock-model", device="cpu:0")
+ reader = ExtractiveReader(model="mock-model", device="cpu:0")
reader.warm_up()
return reader

@@ -94,7 +94,7 @@ def test_to_dict():
assert data == {
"type": "haystack.components.readers.extractive.ExtractiveReader",
"init_parameters": {
"model_name_or_path": "my-model",
"model": "my-model",
"device": None,
"token": None, # don't serialize valid tokens
"top_k": 20,
@@ -117,7 +117,7 @@ def test_to_dict_empty_model_kwargs():
assert data == {
"type": "haystack.components.readers.extractive.ExtractiveReader",
"init_parameters": {
"model_name_or_path": "my-model",
"model": "my-model",
"device": None,
"token": None, # don't serialize valid tokens
"top_k": 20,
@@ -137,7 +137,7 @@ def test_from_dict():
data = {
"type": "haystack.components.readers.extractive.ExtractiveReader",
"init_parameters": {
"model_name_or_path": "my-model",
"model": "my-model",
"device": None,
"token": None,
"top_k": 20,
