
Commit 19a9b7c

Update tests.
1 parent e1ac5ed commit 19a9b7c


3 files changed: +17 -450 lines changed

Lines changed: 17 additions & 183 deletions
@@ -1,196 +1,30 @@
 """Test Chat model for OCI Data Science Model Deployment Endpoint."""
 
-import sys
-from typing import Any, AsyncGenerator, Dict, Generator
 from unittest import mock
 
+import httpx
 import pytest
-from langchain_core.messages import AIMessage, AIMessageChunk
-from requests.exceptions import HTTPError
 
-from langchain_oci.chat_models import (
-    ChatOCIModelDeploymentTGI,
-    ChatOCIModelDeploymentVLLM,
-)
-
-CONST_MODEL_NAME = "odsc-vllm"
-CONST_ENDPOINT = "https://oci.endpoint/ocid/predict"
-CONST_PROMPT = "This is a prompt."
-CONST_COMPLETION = "This is a completion."
-CONST_COMPLETION_ROUTE = "/v1/chat/completions"
-CONST_COMPLETION_RESPONSE = {
-    "id": "chat-123456789",
-    "object": "chat.completion",
-    "created": 123456789,
-    "model": "mistral",
-    "choices": [
-        {
-            "index": 0,
-            "message": {
-                "role": "assistant",
-                "content": CONST_COMPLETION,
-                "tool_calls": [],
-            },
-            "logprobs": None,
-            "finish_reason": "length",
-            "stop_reason": None,
-        }
-    ],
-    "usage": {"prompt_tokens": 115, "total_tokens": 371, "completion_tokens": 256},
-    "prompt_logprobs": None,
-}
-CONST_COMPLETION_RESPONSE_TGI = {"generated_text": CONST_COMPLETION}
-CONST_STREAM_TEMPLATE = (
-    'data: {"id":"chat-123456","object":"chat.completion.chunk","created":123456789,'
-    '"model":"odsc-llm","choices":[{"index":0,"delta":<DELTA>,"finish_reason":null}]}'
-)
-CONST_STREAM_DELTAS = ['{"role":"assistant","content":""}'] + [
-    '{"content":" ' + word + '"}' for word in CONST_COMPLETION.split(" ")
-]
-CONST_STREAM_RESPONSE = (
-    content
-    for content in [
-        CONST_STREAM_TEMPLATE.replace("<DELTA>", delta).encode()
-        for delta in CONST_STREAM_DELTAS
-    ]
-    + [b"data: [DONE]"]
-)
-
-CONST_ASYNC_STREAM_TEMPLATE = (
-    '{"id":"chat-123456","object":"chat.completion.chunk","created":123456789,'
-    '"model":"odsc-llm","choices":[{"index":0,"delta":<DELTA>,"finish_reason":null}]}'
-)
-CONST_ASYNC_STREAM_RESPONSE = (
-    CONST_ASYNC_STREAM_TEMPLATE.replace("<DELTA>", delta).encode()
-    for delta in CONST_STREAM_DELTAS
-)
-
-pytestmark = pytest.mark.skipif(
-    sys.version_info < (3, 9), reason="Requires Python 3.9 or higher"
-)
-
-
-class MockResponse:
-    """Represents a mocked response."""
-
-    def __init__(self, json_data: Dict, status_code: int = 200):
-        self.json_data = json_data
-        self.status_code = status_code
-
-    def raise_for_status(self) -> None:
-        """Mocked raise for status."""
-        if 400 <= self.status_code < 600:
-            raise HTTPError()  # type: ignore[call-arg]
-
-    def json(self) -> Dict:
-        """Returns mocked json data."""
-        return self.json_data
-
-    def iter_lines(self, chunk_size: int = 4096) -> Generator[bytes, None, None]:
-        """Returns a generator of mocked streaming response."""
-        return CONST_STREAM_RESPONSE
-
-    @property
-    def text(self) -> str:
-        """Returns the mocked text representation."""
-        return ""
-
-
-def mocked_requests_post(url: str, **kwargs: Any) -> MockResponse:
-    """Method to mock post requests"""
-
-    payload: dict = kwargs.get("json", {})
-    messages: list = payload.get("messages", [])
-    prompt = messages[0].get("content")
-
-    if prompt == CONST_PROMPT:
-        return MockResponse(json_data=CONST_COMPLETION_RESPONSE)
-
-    return MockResponse(
-        json_data={},
-        status_code=404,
-    )
+from langchain_oci import ChatOCIModelDeployment
 
 
 @pytest.mark.requires("ads")
 @pytest.mark.requires("langchain_openai")
-@mock.patch("ads.common.auth.default_signer", return_value=dict(signer=None))
-@mock.patch("requests.post", side_effect=mocked_requests_post)
-def test_invoke_vllm(*args: Any) -> None:
-    """Tests invoking vLLM endpoint."""
-    llm = ChatOCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
-    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
-    output = llm.invoke(CONST_PROMPT)
-    assert isinstance(output, AIMessage)
-    assert output.content == CONST_COMPLETION
+@mock.patch("ads.aqua.client.client.HttpxOCIAuth")
+def test_authentication_with_ads(*args):
+    with mock.patch(
+        "ads.common.auth.default_signer", return_value=dict(signer=mock.MagicMock())
+    ) as ads_auth:
+        llm = ChatOCIModelDeployment(endpoint="<endpoint>", model="my-model")
+        ads_auth.assert_called_once()
+        assert llm.openai_api_base == "<endpoint>/v1"
+        assert llm.model_name == "my-model"
 
 
-@pytest.mark.requires("ads")
-@pytest.mark.requires("langchain_openai")
-@mock.patch("ads.common.auth.default_signer", return_value=dict(signer=None))
-@mock.patch("requests.post", side_effect=mocked_requests_post)
-def test_invoke_tgi(*args: Any) -> None:
-    """Tests invoking TGI endpoint using OpenAI Spec."""
-    llm = ChatOCIModelDeploymentTGI(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
-    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
-    output = llm.invoke(CONST_PROMPT)
-    assert isinstance(output, AIMessage)
-    assert output.content == CONST_COMPLETION
-
-
-@pytest.mark.requires("ads")
-@pytest.mark.requires("langchain_openai")
-@mock.patch("ads.common.auth.default_signer", return_value=dict(signer=None))
-@mock.patch("requests.post", side_effect=mocked_requests_post)
-def test_stream_vllm(*args: Any) -> None:
-    """Tests streaming with vLLM endpoint using OpenAI spec."""
-    llm = ChatOCIModelDeploymentVLLM(
-        endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
-    )
-    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
-    output = None
-    count = 0
-    for chunk in llm.stream(CONST_PROMPT):
-        assert isinstance(chunk, AIMessageChunk)
-        if output is None:
-            output = chunk
-        else:
-            output += chunk
-        count += 1
-    assert count == 5
-    assert output is not None
-    if output is not None:
-        assert str(output.content).strip() == CONST_COMPLETION
-
-
-async def mocked_async_streaming_response(
-    *args: Any, **kwargs: Any
-) -> AsyncGenerator[bytes, None]:
-    """Returns mocked response for async streaming."""
-    for item in CONST_ASYNC_STREAM_RESPONSE:
-        yield item
-
-
-@pytest.mark.asyncio
-@pytest.mark.requires("ads")
 @pytest.mark.requires("langchain_openai")
-@mock.patch(
-    "ads.common.auth.default_signer", return_value=dict(signer=mock.MagicMock())
-)
-@mock.patch(
-    "langchain_oci.llms.oci_data_science_model_deployment_endpoint.BaseOCIModelDeployment._arequest",
-    mock.MagicMock(),
-)
-async def test_stream_async(*args: Any) -> None:
-    """Tests async streaming."""
-    llm = ChatOCIModelDeploymentVLLM(
-        endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
-    )
-    assert llm._headers().get("route") == CONST_COMPLETION_ROUTE
-    with mock.patch.object(
-        llm,
-        "_aiter_sse",
-        mock.MagicMock(return_value=mocked_async_streaming_response()),
-    ):
-        chunks = [str(chunk.content) async for chunk in llm.astream(CONST_PROMPT)]
-        assert "".join(chunks).strip() == CONST_COMPLETION
+def test_authentication_with_custom_client():
+    http_client = httpx.Client()
+    llm = ChatOCIModelDeployment(base_url="<endpoint>/v1", http_client=http_client)
+    assert llm.http_client == http_client
+    assert llm.openai_api_base == "<endpoint>/v1"
+    assert llm.model_name == "odsc-llm"
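For context on the diff above: the rewritten tests drop the mocked-requests tests for ChatOCIModelDeploymentVLLM and ChatOCIModelDeploymentTGI and instead exercise the OpenAI-compatible ChatOCIModelDeployment class from langchain_oci. The following is a minimal usage sketch of the two construction paths the new tests assert, not a verified API reference; the endpoint URL is a placeholder, and the attribute names (openai_api_base, model_name, http_client) are taken directly from the test assertions.

    import httpx

    from langchain_oci import ChatOCIModelDeployment

    # Path 1: rely on ADS-managed OCI auth (the test patches
    # ads.common.auth.default_signer and ads.aqua.client.client.HttpxOCIAuth).
    llm = ChatOCIModelDeployment(endpoint="https://<md-endpoint>", model="my-model")
    # The endpoint is exposed as an OpenAI-compatible base URL with a /v1 suffix.
    assert llm.openai_api_base == "https://<md-endpoint>/v1"

    # Path 2: supply a pre-configured httpx client instead of ADS-managed auth.
    llm = ChatOCIModelDeployment(
        base_url="https://<md-endpoint>/v1", http_client=httpx.Client()
    )
    # When no model is given, model_name falls back to the default "odsc-llm".
    assert llm.model_name == "odsc-llm"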

libs/oci/tests/unit_tests/chat_models/test_oci_model_deployment_endpoint.py

Lines changed: 0 additions & 98 deletions
This file was deleted.
