|
1 | 1 | """Test Chat model for OCI Data Science Model Deployment Endpoint.""" |
2 | 2 |
|
3 | | -import sys |
4 | | -from typing import Any, AsyncGenerator, Dict, Generator |
5 | 3 | from unittest import mock |
6 | 4 |
|
| 5 | +import httpx |
7 | 6 | import pytest |
8 | | -from langchain_core.messages import AIMessage, AIMessageChunk |
9 | | -from requests.exceptions import HTTPError |
10 | 7 |
|
11 | | -from langchain_oci.chat_models import ( |
12 | | - ChatOCIModelDeploymentTGI, |
13 | | - ChatOCIModelDeploymentVLLM, |
14 | | -) |
15 | | - |
16 | | -CONST_MODEL_NAME = "odsc-vllm" |
17 | | -CONST_ENDPOINT = "https://oci.endpoint/ocid/predict" |
18 | | -CONST_PROMPT = "This is a prompt." |
19 | | -CONST_COMPLETION = "This is a completion." |
20 | | -CONST_COMPLETION_ROUTE = "/v1/chat/completions" |
21 | | -CONST_COMPLETION_RESPONSE = { |
22 | | - "id": "chat-123456789", |
23 | | - "object": "chat.completion", |
24 | | - "created": 123456789, |
25 | | - "model": "mistral", |
26 | | - "choices": [ |
27 | | - { |
28 | | - "index": 0, |
29 | | - "message": { |
30 | | - "role": "assistant", |
31 | | - "content": CONST_COMPLETION, |
32 | | - "tool_calls": [], |
33 | | - }, |
34 | | - "logprobs": None, |
35 | | - "finish_reason": "length", |
36 | | - "stop_reason": None, |
37 | | - } |
38 | | - ], |
39 | | - "usage": {"prompt_tokens": 115, "total_tokens": 371, "completion_tokens": 256}, |
40 | | - "prompt_logprobs": None, |
41 | | -} |
42 | | -CONST_COMPLETION_RESPONSE_TGI = {"generated_text": CONST_COMPLETION} |
43 | | -CONST_STREAM_TEMPLATE = ( |
44 | | - 'data: {"id":"chat-123456","object":"chat.completion.chunk","created":123456789,' |
45 | | - '"model":"odsc-llm","choices":[{"index":0,"delta":<DELTA>,"finish_reason":null}]}' |
46 | | -) |
47 | | -CONST_STREAM_DELTAS = ['{"role":"assistant","content":""}'] + [ |
48 | | - '{"content":" ' + word + '"}' for word in CONST_COMPLETION.split(" ") |
49 | | -] |
50 | | -CONST_STREAM_RESPONSE = ( |
51 | | - content |
52 | | - for content in [ |
53 | | - CONST_STREAM_TEMPLATE.replace("<DELTA>", delta).encode() |
54 | | - for delta in CONST_STREAM_DELTAS |
55 | | - ] |
56 | | - + [b"data: [DONE]"] |
57 | | -) |
58 | | - |
59 | | -CONST_ASYNC_STREAM_TEMPLATE = ( |
60 | | - '{"id":"chat-123456","object":"chat.completion.chunk","created":123456789,' |
61 | | - '"model":"odsc-llm","choices":[{"index":0,"delta":<DELTA>,"finish_reason":null}]}' |
62 | | -) |
63 | | -CONST_ASYNC_STREAM_RESPONSE = ( |
64 | | - CONST_ASYNC_STREAM_TEMPLATE.replace("<DELTA>", delta).encode() |
65 | | - for delta in CONST_STREAM_DELTAS |
66 | | -) |
67 | | - |
68 | | -pytestmark = pytest.mark.skipif( |
69 | | - sys.version_info < (3, 9), reason="Requires Python 3.9 or higher" |
70 | | -) |
71 | | - |
72 | | - |
73 | | -class MockResponse: |
74 | | - """Represents a mocked response.""" |
75 | | - |
76 | | - def __init__(self, json_data: Dict, status_code: int = 200): |
77 | | - self.json_data = json_data |
78 | | - self.status_code = status_code |
79 | | - |
80 | | - def raise_for_status(self) -> None: |
81 | | - """Mocked raise for status.""" |
82 | | - if 400 <= self.status_code < 600: |
83 | | - raise HTTPError() # type: ignore[call-arg] |
84 | | - |
85 | | - def json(self) -> Dict: |
86 | | - """Returns mocked json data.""" |
87 | | - return self.json_data |
88 | | - |
89 | | - def iter_lines(self, chunk_size: int = 4096) -> Generator[bytes, None, None]: |
90 | | - """Returns a generator of mocked streaming response.""" |
91 | | - return CONST_STREAM_RESPONSE |
92 | | - |
93 | | - @property |
94 | | - def text(self) -> str: |
95 | | - """Returns the mocked text representation.""" |
96 | | - return "" |
97 | | - |
98 | | - |
99 | | -def mocked_requests_post(url: str, **kwargs: Any) -> MockResponse: |
100 | | - """Method to mock post requests""" |
101 | | - |
102 | | - payload: dict = kwargs.get("json", {}) |
103 | | - messages: list = payload.get("messages", []) |
104 | | - prompt = messages[0].get("content") |
105 | | - |
106 | | - if prompt == CONST_PROMPT: |
107 | | - return MockResponse(json_data=CONST_COMPLETION_RESPONSE) |
108 | | - |
109 | | - return MockResponse( |
110 | | - json_data={}, |
111 | | - status_code=404, |
112 | | - ) |
| 8 | +from langchain_oci import ChatOCIModelDeployment |
113 | 9 |
|
114 | 10 |
|
115 | 11 | @pytest.mark.requires("ads") |
116 | 12 | @pytest.mark.requires("langchain_openai") |
117 | | -@mock.patch("ads.common.auth.default_signer", return_value=dict(signer=None)) |
118 | | -@mock.patch("requests.post", side_effect=mocked_requests_post) |
119 | | -def test_invoke_vllm(*args: Any) -> None: |
120 | | - """Tests invoking vLLM endpoint.""" |
121 | | - llm = ChatOCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME) |
122 | | - assert llm._headers().get("route") == CONST_COMPLETION_ROUTE |
123 | | - output = llm.invoke(CONST_PROMPT) |
124 | | - assert isinstance(output, AIMessage) |
125 | | - assert output.content == CONST_COMPLETION |
| 13 | +@mock.patch("ads.aqua.client.client.HttpxOCIAuth") |
| 14 | +def test_authentication_with_ads(*args): |
| 15 | + with mock.patch( |
| 16 | + "ads.common.auth.default_signer", return_value=dict(signer=mock.MagicMock()) |
| 17 | + ) as ads_auth: |
| 18 | + llm = ChatOCIModelDeployment(endpoint="<endpoint>", model="my-model") |
| 19 | + ads_auth.assert_called_once() |
| 20 | + assert llm.openai_api_base == "<endpoint>/v1" |
| 21 | + assert llm.model_name == "my-model" |
126 | 22 |
|
127 | 23 |
|
128 | | -@pytest.mark.requires("ads") |
129 | | -@pytest.mark.requires("langchain_openai") |
130 | | -@mock.patch("ads.common.auth.default_signer", return_value=dict(signer=None)) |
131 | | -@mock.patch("requests.post", side_effect=mocked_requests_post) |
132 | | -def test_invoke_tgi(*args: Any) -> None: |
133 | | - """Tests invoking TGI endpoint using OpenAI Spec.""" |
134 | | - llm = ChatOCIModelDeploymentTGI(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME) |
135 | | - assert llm._headers().get("route") == CONST_COMPLETION_ROUTE |
136 | | - output = llm.invoke(CONST_PROMPT) |
137 | | - assert isinstance(output, AIMessage) |
138 | | - assert output.content == CONST_COMPLETION |
139 | | - |
140 | | - |
141 | | -@pytest.mark.requires("ads") |
142 | | -@pytest.mark.requires("langchain_openai") |
143 | | -@mock.patch("ads.common.auth.default_signer", return_value=dict(signer=None)) |
144 | | -@mock.patch("requests.post", side_effect=mocked_requests_post) |
145 | | -def test_stream_vllm(*args: Any) -> None: |
146 | | - """Tests streaming with vLLM endpoint using OpenAI spec.""" |
147 | | - llm = ChatOCIModelDeploymentVLLM( |
148 | | - endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True |
149 | | - ) |
150 | | - assert llm._headers().get("route") == CONST_COMPLETION_ROUTE |
151 | | - output = None |
152 | | - count = 0 |
153 | | - for chunk in llm.stream(CONST_PROMPT): |
154 | | - assert isinstance(chunk, AIMessageChunk) |
155 | | - if output is None: |
156 | | - output = chunk |
157 | | - else: |
158 | | - output += chunk |
159 | | - count += 1 |
160 | | - assert count == 5 |
161 | | - assert output is not None |
162 | | - if output is not None: |
163 | | - assert str(output.content).strip() == CONST_COMPLETION |
164 | | - |
165 | | - |
166 | | -async def mocked_async_streaming_response( |
167 | | - *args: Any, **kwargs: Any |
168 | | -) -> AsyncGenerator[bytes, None]: |
169 | | - """Returns mocked response for async streaming.""" |
170 | | - for item in CONST_ASYNC_STREAM_RESPONSE: |
171 | | - yield item |
172 | | - |
173 | | - |
174 | | -@pytest.mark.asyncio |
175 | | -@pytest.mark.requires("ads") |
176 | 24 | @pytest.mark.requires("langchain_openai") |
177 | | -@mock.patch( |
178 | | - "ads.common.auth.default_signer", return_value=dict(signer=mock.MagicMock()) |
179 | | -) |
180 | | -@mock.patch( |
181 | | - "langchain_oci.llms.oci_data_science_model_deployment_endpoint.BaseOCIModelDeployment._arequest", |
182 | | - mock.MagicMock(), |
183 | | -) |
184 | | -async def test_stream_async(*args: Any) -> None: |
185 | | - """Tests async streaming.""" |
186 | | - llm = ChatOCIModelDeploymentVLLM( |
187 | | - endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True |
188 | | - ) |
189 | | - assert llm._headers().get("route") == CONST_COMPLETION_ROUTE |
190 | | - with mock.patch.object( |
191 | | - llm, |
192 | | - "_aiter_sse", |
193 | | - mock.MagicMock(return_value=mocked_async_streaming_response()), |
194 | | - ): |
195 | | - chunks = [str(chunk.content) async for chunk in llm.astream(CONST_PROMPT)] |
196 | | - assert "".join(chunks).strip() == CONST_COMPLETION |
| 25 | +def test_authentication_with_custom_client(): |
| 26 | + http_client = httpx.Client() |
| 27 | + llm = ChatOCIModelDeployment(base_url="<endpoint>/v1", http_client=http_client) |
| 28 | + assert llm.http_client == http_client |
| 29 | + assert llm.openai_api_base == "<endpoint>/v1" |
| 30 | + assert llm.model_name == "odsc-llm" |
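
For context, the two new tests cover the two construction paths of ChatOCIModelDeployment. The sketch below mirrors them in application code; it is a minimal illustration, assuming the class follows the standard LangChain chat-model interface (invoke is not exercised in this diff) and using "https://<endpoint>" as a placeholder for a real model deployment URL.

    import httpx

    from langchain_oci import ChatOCIModelDeployment

    # Path 1: pass the deployment endpoint and model name. Credentials are
    # resolved through ads.common.auth.default_signer, and "/v1" is appended
    # to the endpoint to form openai_api_base (as asserted in the first test).
    llm = ChatOCIModelDeployment(endpoint="https://<endpoint>", model="my-model")

    # Path 2: pass an OpenAI-style base_url together with a custom httpx client
    # (for example, one that carries your own authentication). The model name
    # falls back to "odsc-llm" (as asserted in the second test).
    llm = ChatOCIModelDeployment(
        base_url="https://<endpoint>/v1",
        http_client=httpx.Client(),
    )

    # Standard LangChain usage (assumed; not covered by the tests above).
    print(llm.invoke("Tell me a joke.").content)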