Skip to content

Commit

Permalink
feat: rename `model_name` or `model_name_or_path` to `model` in generators (#6715)
Browse files Browse the repository at this point in the history

* renamed model_name or model_name_or_path to model

* added release notes

* Update releasenotes/notes/renamed-model_name-or-model_name_or_path-to-model-184490cbb66c4d7c.yaml

---------

Co-authored-by: ZanSara <[email protected]>
  • Loading branch information
sahusiddharth and ZanSara authored Jan 12, 2024
1 parent 80c3e68 commit dbdeb82
Show file tree
Hide file tree
Showing 10 changed files with 66 additions and 75 deletions.
2 changes: 1 addition & 1 deletion haystack/components/generators/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def __init__(
self.azure_endpoint = azure_endpoint
self.azure_deployment = azure_deployment
self.organization = organization
self.model_name: str = azure_deployment or "gpt-35-turbo"
self.model: str = azure_deployment or "gpt-35-turbo"

self.client = AzureOpenAI(
api_version=api_version,
Expand Down
2 changes: 1 addition & 1 deletion haystack/components/generators/chat/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def __init__(
self.azure_endpoint = azure_endpoint
self.azure_deployment = azure_deployment
self.organization = organization
self.model_name = azure_deployment or "gpt-35-turbo"
self.model = azure_deployment or "gpt-35-turbo"

self.client = AzureOpenAI(
api_version=api_version,
Expand Down
18 changes: 9 additions & 9 deletions haystack/components/generators/chat/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,19 +63,19 @@ class OpenAIChatGenerator:
def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "gpt-3.5-turbo",
model: str = "gpt-3.5-turbo",
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
api_base_url: Optional[str] = None,
organization: Optional[str] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Creates an instance of OpenAIChatGenerator. Unless specified otherwise in the `model_name`, this is for OpenAI's
Creates an instance of OpenAIChatGenerator. Unless specified otherwise in the `model`, this is for OpenAI's
GPT-3.5 model.
:param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the
environment variable OPENAI_API_KEY (recommended).
:param model_name: The name of the model to use.
:param model: The name of the model to use.
:param streaming_callback: A callback function that is called when a new token is received from the stream.
The callback function accepts StreamingChunk as an argument.
:param api_base_url: An optional base URL.
Expand All @@ -101,7 +101,7 @@ def __init__(
- `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the
values are the bias to add to that token.
"""
self.model_name = model_name
self.model = model
self.generation_kwargs = generation_kwargs or {}
self.streaming_callback = streaming_callback
self.api_base_url = api_base_url
Expand All @@ -112,7 +112,7 @@ def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
"""
return {"model": self.model_name}
return {"model": self.model}

def to_dict(self) -> Dict[str, Any]:
"""
Expand All @@ -122,7 +122,7 @@ def to_dict(self) -> Dict[str, Any]:
callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None
return default_to_dict(
self,
model_name=self.model_name,
model=self.model,
streaming_callback=callback_name,
api_base_url=self.api_base_url,
organization=self.organization,
Expand Down Expand Up @@ -162,7 +162,7 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
openai_formatted_messages = self._convert_to_openai_format(messages)

chat_completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = self.client.chat.completions.create(
model=self.model_name,
model=self.model,
messages=openai_formatted_messages, # type: ignore # openai expects list of specific message types
stream=self.streaming_callback is not None,
**generation_kwargs,
Expand Down Expand Up @@ -335,7 +335,7 @@ class GPTChatGenerator(OpenAIChatGenerator):
def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "gpt-3.5-turbo",
model: str = "gpt-3.5-turbo",
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
api_base_url: Optional[str] = None,
organization: Optional[str] = None,
Expand All @@ -349,7 +349,7 @@ def __init__(
)
super().__init__(
api_key=api_key,
model_name=model_name,
model=model,
streaming_callback=streaming_callback,
api_base_url=api_base_url,
organization=organization,
Expand Down
10 changes: 5 additions & 5 deletions haystack/components/generators/hugging_face_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class HuggingFaceLocalGenerator:
```python
from haystack.components.generators import HuggingFaceLocalGenerator
generator = HuggingFaceLocalGenerator(model_name_or_path="google/flan-t5-large",
generator = HuggingFaceLocalGenerator(model="google/flan-t5-large",
task="text2text-generation",
generation_kwargs={
"max_new_tokens": 100,
Expand All @@ -80,7 +80,7 @@ class HuggingFaceLocalGenerator:

def __init__(
self,
model_name_or_path: str = "google/flan-t5-base",
model: str = "google/flan-t5-base",
task: Optional[Literal["text-generation", "text2text-generation"]] = None,
device: Optional[str] = None,
token: Optional[Union[str, bool]] = None,
Expand All @@ -89,7 +89,7 @@ def __init__(
stop_words: Optional[List[str]] = None,
):
"""
:param model_name_or_path: The name or path of a Hugging Face model for text generation,
:param model: The name or path of a Hugging Face model for text generation,
for example, "google/flan-t5-large".
If the model is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored.
:param task: The task for the Hugging Face pipeline.
Expand All @@ -113,7 +113,7 @@ def __init__(
:param huggingface_pipeline_kwargs: Dictionary containing keyword arguments used to initialize the
Hugging Face pipeline for text generation.
These keyword arguments provide fine-grained control over the Hugging Face pipeline.
In case of duplication, these kwargs override `model_name_or_path`, `task`, `device`, and `token` init parameters.
In case of duplication, these kwargs override `model`, `task`, `device`, and `token` init parameters.
See Hugging Face's [documentation](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.pipeline.task)
for more information on the available kwargs.
In this dictionary, you can also include `model_kwargs` to specify the kwargs
Expand All @@ -131,7 +131,7 @@ def __init__(

# check if the huggingface_pipeline_kwargs contain the essential parameters
# otherwise, populate them with values from other init parameters
huggingface_pipeline_kwargs.setdefault("model", model_name_or_path)
huggingface_pipeline_kwargs.setdefault("model", model)
huggingface_pipeline_kwargs.setdefault("token", token)
if (
device is not None
Expand Down
18 changes: 9 additions & 9 deletions haystack/components/generators/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,20 @@ class OpenAIGenerator:
def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "gpt-3.5-turbo",
model: str = "gpt-3.5-turbo",
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
api_base_url: Optional[str] = None,
organization: Optional[str] = None,
system_prompt: Optional[str] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Creates an instance of OpenAIGenerator. Unless specified otherwise in the `model_name`, this is for OpenAI's
Creates an instance of OpenAIGenerator. Unless specified otherwise in the `model`, this is for OpenAI's
GPT-3.5 model.
:param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the
environment variable OPENAI_API_KEY (recommended).
:param model_name: The name of the model to use.
:param model: The name of the model to use.
:param streaming_callback: A callback function that is called when a new token is received from the stream.
The callback function accepts StreamingChunk as an argument.
:param api_base_url: An optional base URL.
Expand Down Expand Up @@ -92,7 +92,7 @@ def __init__(
- `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the
values are the bias to add to that token.
"""
self.model_name = model_name
self.model = model
self.generation_kwargs = generation_kwargs or {}
self.system_prompt = system_prompt
self.streaming_callback = streaming_callback
Expand All @@ -105,7 +105,7 @@ def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
"""
return {"model": self.model_name}
return {"model": self.model}

def to_dict(self) -> Dict[str, Any]:
"""
Expand All @@ -115,7 +115,7 @@ def to_dict(self) -> Dict[str, Any]:
callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None
return default_to_dict(
self,
model_name=self.model_name,
model=self.model,
streaming_callback=callback_name,
api_base_url=self.api_base_url,
generation_kwargs=self.generation_kwargs,
Expand Down Expand Up @@ -161,7 +161,7 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
openai_formatted_messages = self._convert_to_openai_format(messages)

completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = self.client.chat.completions.create(
model=self.model_name,
model=self.model,
messages=openai_formatted_messages, # type: ignore
stream=self.streaming_callback is not None,
**generation_kwargs,
Expand Down Expand Up @@ -280,7 +280,7 @@ class GPTGenerator(OpenAIGenerator):
def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "gpt-3.5-turbo",
model: str = "gpt-3.5-turbo",
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
api_base_url: Optional[str] = None,
organization: Optional[str] = None,
Expand All @@ -295,7 +295,7 @@ def __init__(
)
super().__init__(
api_key=api_key,
model_name=model_name,
model=model,
streaming_callback=streaming_callback,
api_base_url=api_base_url,
organization=organization,
Expand Down
2 changes: 1 addition & 1 deletion haystack/pipeline_utils/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ class _OpenAIResolved(_GeneratorResolver):
def resolve(self, model_key: str, api_key: str) -> Any:
# does the model_key match the pattern OpenAI GPT pattern?
if re.match(r"^gpt-4-.*", model_key) or re.match(r"^gpt-3.5-.*", model_key):
return OpenAIGenerator(model_name=model_key, api_key=api_key)
return OpenAIGenerator(model=model_key, api_key=api_key)
return None


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
---
upgrade:
- Rename the generator parameters `model_name` and `model_name_or_path` to `model`. This change affects all Generator classes.
26 changes: 12 additions & 14 deletions test/components/generators/chat/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class TestOpenAIChatGenerator:
def test_init_default(self):
component = OpenAIChatGenerator(api_key="test-api-key")
assert component.client.api_key == "test-api-key"
assert component.model_name == "gpt-3.5-turbo"
assert component.model == "gpt-3.5-turbo"
assert component.streaming_callback is None
assert not component.generation_kwargs

Expand All @@ -32,13 +32,13 @@ def test_init_fail_wo_api_key(self, monkeypatch):
def test_init_with_parameters(self):
component = OpenAIChatGenerator(
api_key="test-api-key",
model_name="gpt-4",
model="gpt-4",
streaming_callback=default_streaming_callback,
api_base_url="test-base-url",
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
)
assert component.client.api_key == "test-api-key"
assert component.model_name == "gpt-4"
assert component.model == "gpt-4"
assert component.streaming_callback is default_streaming_callback
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}

Expand All @@ -48,7 +48,7 @@ def test_to_dict_default(self):
assert data == {
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
"init_parameters": {
"model_name": "gpt-3.5-turbo",
"model": "gpt-3.5-turbo",
"organization": None,
"streaming_callback": None,
"api_base_url": None,
Expand All @@ -59,7 +59,7 @@ def test_to_dict_default(self):
def test_to_dict_with_parameters(self):
component = OpenAIChatGenerator(
api_key="test-api-key",
model_name="gpt-4",
model="gpt-4",
streaming_callback=default_streaming_callback,
api_base_url="test-base-url",
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
Expand All @@ -68,7 +68,7 @@ def test_to_dict_with_parameters(self):
assert data == {
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
"init_parameters": {
"model_name": "gpt-4",
"model": "gpt-4",
"organization": None,
"api_base_url": "test-base-url",
"streaming_callback": "haystack.components.generators.utils.default_streaming_callback",
Expand All @@ -79,7 +79,7 @@ def test_to_dict_with_parameters(self):
def test_to_dict_with_lambda_streaming_callback(self):
component = OpenAIChatGenerator(
api_key="test-api-key",
model_name="gpt-4",
model="gpt-4",
streaming_callback=lambda x: x,
api_base_url="test-base-url",
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
Expand All @@ -88,7 +88,7 @@ def test_to_dict_with_lambda_streaming_callback(self):
assert data == {
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
"init_parameters": {
"model_name": "gpt-4",
"model": "gpt-4",
"organization": None,
"api_base_url": "test-base-url",
"streaming_callback": "chat.test_openai.<lambda>",
Expand All @@ -100,14 +100,14 @@ def test_from_dict(self):
data = {
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
"init_parameters": {
"model_name": "gpt-4",
"model": "gpt-4",
"api_base_url": "test-base-url",
"streaming_callback": "haystack.components.generators.utils.default_streaming_callback",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
},
}
component = OpenAIChatGenerator.from_dict(data)
assert component.model_name == "gpt-4"
assert component.model == "gpt-4"
assert component.streaming_callback is default_streaming_callback
assert component.api_base_url == "test-base-url"
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
Expand All @@ -117,7 +117,7 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch):
data = {
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
"init_parameters": {
"model_name": "gpt-4",
"model": "gpt-4",
"organization": None,
"api_base_url": "test-base-url",
"streaming_callback": "haystack.components.generators.utils.default_streaming_callback",
Expand Down Expand Up @@ -222,9 +222,7 @@ def test_live_run(self):
)
@pytest.mark.integration
def test_live_run_wrong_model(self, chat_messages):
component = OpenAIChatGenerator(
model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY")
)
component = OpenAIChatGenerator(model="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY"))
with pytest.raises(OpenAIError):
component.run(chat_messages)

Expand Down
Loading

0 comments on commit dbdeb82

Please sign in to comment.