@@ -1,10 +1,17 @@
 import logging
-from typing import Optional, List, Dict, Union, Any
+from typing import Any, Dict, List, Optional, Union

 from haystack.nodes.prompt.invocation_layer.handlers import DefaultTokenStreamingHandler, TokenStreamingHandler
 from haystack.nodes.prompt.invocation_layer.open_ai import OpenAIInvocationLayer
 from haystack.nodes.prompt.invocation_layer.utils import has_azure_parameters
-from haystack.utils.openai_utils import openai_request, _check_openai_finish_reason, count_openai_tokens_messages
+from haystack.utils.openai_utils import (
+    _check_openai_finish_reason,
+    check_openai_async_policy_violation,
+    check_openai_policy_violation,
+    count_openai_tokens_messages,
+    openai_async_request,
+    openai_request,
+)

 logger = logging.getLogger(__name__)

@@ -43,45 +50,6 @@ def __init__(
         """
         super().__init__(api_key, model_name_or_path, max_length, api_base=api_base, **kwargs)

-    def _execute_openai_request(
-        self, prompt: Union[str, List[Dict]], base_payload: Dict, kwargs_with_defaults: Dict, stream: bool
-    ):
-        """
-        For more details, see [OpenAI ChatGPT API reference](https://platform.openai.com/docs/api-reference/chat).
-        """
-        if isinstance(prompt, str):
-            messages = [{"role": "user", "content": prompt}]
-        elif isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], dict):
-            messages = prompt
-        else:
-            raise ValueError(
-                f"The prompt format is different than what the model expects. "
-                f"The model {self.model_name_or_path} requires either a string or messages in the ChatML format. "
-                f"For more details, see this [GitHub discussion](https://github.com/openai/openai-python/blob/main/chatml.md)."
-            )
-        extra_payload = {"messages": messages}
-        payload = {**base_payload, **extra_payload}
-        if not stream:
-            response = openai_request(url=self.url, headers=self.headers, payload=payload)
-            _check_openai_finish_reason(result=response, payload=payload)
-            assistant_response = [choice["message"]["content"].strip() for choice in response["choices"]]
-        else:
-            response = openai_request(
-                url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
-            )
-            handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
-            assistant_response = self._process_streaming_response(response=response, stream_handler=handler)
-
-        # Although ChatGPT generates text until stop words are encountered, unfortunately it includes the stop word
-        # We want to exclude it to be consistent with other invocation layers
-        if "stop" in kwargs_with_defaults and kwargs_with_defaults["stop"] is not None:
-            stop_words = kwargs_with_defaults["stop"]
-            for idx, _ in enumerate(assistant_response):
-                for stop_word in stop_words:
-                    assistant_response[idx] = assistant_response[idx].replace(stop_word, "").strip()
-
-        return assistant_response
-
     def _extract_token(self, event_data: Dict[str, Any]):
         delta = event_data["choices"][0]["delta"]
         if "content" in delta:
@@ -141,3 +109,109 @@ def supports(cls, model_name_or_path: str, **kwargs) -> bool:
             and not "gpt-3.5-turbo-instruct" in model_name_or_path
         )
         return valid_model and not has_azure_parameters(**kwargs)
+
+    async def ainvoke(self, *args, **kwargs):
+        """
+        Invokes a prompt on the model. It takes in either a prompt string or a list of ChatML messages
+        and returns a list of responses using a REST invocation.
+
+        :return: A list of generated responses.
+
+        Note: Only kwargs relevant to OpenAI are passed to the OpenAI REST API. Other kwargs are ignored.
+        For more details, see the OpenAI [documentation](https://platform.openai.com/docs/api-reference/completions/create).
+        """
+        prompt, base_payload, kwargs_with_defaults, stream, moderation = self._prepare_invoke(*args, **kwargs)
+
+        if moderation and await check_openai_async_policy_violation(input=prompt, headers=self.headers):
+            logger.info("Prompt '%s' will not be sent to OpenAI due to potential policy violation.", prompt)
+            return []
+
+        if isinstance(prompt, str):
+            messages = [{"role": "user", "content": prompt}]
+        elif isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], dict):
+            messages = prompt
+        else:
+            raise ValueError(
+                f"The prompt format is different than what the model expects. "
+                f"The model {self.model_name_or_path} requires either a string or messages in the ChatML format. "
+                f"For more details, see this [GitHub discussion](https://github.com/openai/openai-python/blob/main/chatml.md)."
+            )
+        extra_payload = {"messages": messages}
+        payload = {**base_payload, **extra_payload}
+        if not stream:
+            response = await openai_async_request(url=self.url, headers=self.headers, payload=payload)
+            _check_openai_finish_reason(result=response, payload=payload)
+            assistant_response = [choice["message"]["content"].strip() for choice in response["choices"]]
+        else:
+            response = await openai_async_request(
+                url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
+            )
+            handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
+            assistant_response = self._process_streaming_response(response=response, stream_handler=handler)
+
+        # Although ChatGPT generates text until stop words are encountered, unfortunately it includes the stop word
+        # We want to exclude it to be consistent with other invocation layers
+        if "stop" in kwargs_with_defaults and kwargs_with_defaults["stop"] is not None:
+            stop_words = kwargs_with_defaults["stop"]
+            for idx, _ in enumerate(assistant_response):
+                for stop_word in stop_words:
+                    assistant_response[idx] = assistant_response[idx].replace(stop_word, "").strip()
+
+        if moderation and await check_openai_async_policy_violation(input=assistant_response, headers=self.headers):
+            logger.info("Response '%s' will not be returned due to potential policy violation.", assistant_response)
+            return []
+
+        return assistant_response
+
+    def invoke(self, *args, **kwargs):
+        """
+        Invokes a prompt on the model. It takes in either a prompt string or a list of ChatML messages
+        and returns a list of responses using a REST invocation.
+
+        :return: A list of generated responses.
+
+        Note: Only kwargs relevant to OpenAI are passed to the OpenAI REST API. Other kwargs are ignored.
+        For more details, see the OpenAI [documentation](https://platform.openai.com/docs/api-reference/completions/create).
+        """
+        prompt, base_payload, kwargs_with_defaults, stream, moderation = self._prepare_invoke(*args, **kwargs)
+
+        if moderation and check_openai_policy_violation(input=prompt, headers=self.headers):
+            logger.info("Prompt '%s' will not be sent to OpenAI due to potential policy violation.", prompt)
+            return []
+
+        if isinstance(prompt, str):
+            messages = [{"role": "user", "content": prompt}]
+        elif isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], dict):
+            messages = prompt
+        else:
+            raise ValueError(
+                f"The prompt format is different than what the model expects. "
+                f"The model {self.model_name_or_path} requires either a string or messages in the ChatML format. "
+                f"For more details, see this [GitHub discussion](https://github.com/openai/openai-python/blob/main/chatml.md)."
+            )
+        extra_payload = {"messages": messages}
+        payload = {**base_payload, **extra_payload}
+        if not stream:
+            response = openai_request(url=self.url, headers=self.headers, payload=payload)
+            _check_openai_finish_reason(result=response, payload=payload)
+            assistant_response = [choice["message"]["content"].strip() for choice in response["choices"]]
+        else:
+            response = openai_request(
+                url=self.url, headers=self.headers, payload=payload, read_response=False, stream=True
+            )
+            handler: TokenStreamingHandler = kwargs_with_defaults.pop("stream_handler", DefaultTokenStreamingHandler())
+            assistant_response = self._process_streaming_response(response=response, stream_handler=handler)
+
+        # Although ChatGPT generates text until stop words are encountered, unfortunately it includes the stop word
+        # We want to exclude it to be consistent with other invocation layers
+        if "stop" in kwargs_with_defaults and kwargs_with_defaults["stop"] is not None:
+            stop_words = kwargs_with_defaults["stop"]
+            for idx, _ in enumerate(assistant_response):
+                for stop_word in stop_words:
+                    assistant_response[idx] = assistant_response[idx].replace(stop_word, "").strip()
+
+        if moderation and check_openai_policy_violation(input=assistant_response, headers=self.headers):
+            logger.info("Response '%s' will not be returned due to potential policy violation.", assistant_response)
+            return []
+
+        return assistant_response
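
For context, here is a minimal usage sketch of the two entry points this commit adds. It assumes the surrounding class is Haystack's ChatGPTInvocationLayer (this file's layer) and that `invoke`/`ainvoke` read the prompt from a `prompt` kwarg, as the `_prepare_invoke` call suggests; the API key and model name below are placeholders, not values from the diff.

import asyncio

from haystack.nodes.prompt.invocation_layer import ChatGPTInvocationLayer

# Placeholder credentials; substitute a real OpenAI API key.
layer = ChatGPTInvocationLayer(api_key="sk-...", model_name_or_path="gpt-3.5-turbo")

# Synchronous call: a plain string is wrapped into a single ChatML user message.
responses = layer.invoke(prompt="What is the capital of France?")
print(responses)

# Asynchronous call: a list of ChatML messages is passed through unchanged,
# and the request goes out via openai_async_request.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize the rules of chess in one sentence."},
]
responses = asyncio.run(layer.ainvoke(prompt=messages))
print(responses)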