
Commit

Use ollama Python client within OllamaLLM (#307)
Co-authored-by: Alvaro Bartolome <[email protected]>
Co-authored-by: plaguss <[email protected]>
Authored by 3 people on Feb 2, 2024
1 parent b7fcfb6 commit 979cf45
Showing 7 changed files with 136 additions and 60 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -41,7 +41,7 @@ jobs:

- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: pip install -e .[dev,tests,vertexai,openai,together,argilla,mistralai]
run: pip install -e .[dev,tests,vertexai,ollama,openai,together,argilla,mistralai]
- name: Lint
run: make lint

1 change: 1 addition & 0 deletions README.md
@@ -32,6 +32,7 @@ In addition, the following extras are available:
- `openai`: for using OpenAI API models via the `OpenAILLM` integration.
- `vllm`: for using [vllm](https://github.com/vllm-project/vllm) serving engine via the `vLLM` integration.
- `llama-cpp`: for using [llama-cpp-python](https://github.com/abetlen/llama-cpp-python) as Python bindings for `llama.cpp`.
- `ollama`: for using [Ollama](https://github.com/ollama/ollama) and their available models via their Python client.
- `together`: for using [Together Inference](https://www.together.ai/products) via their Python client.
- `anyscale`: for using [Anyscale endpoints](https://www.anyscale.com/endpoints).
- `ollama`: for using [Ollama](https://ollama.ai/).
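With the new `ollama` extra in place, the integration can be exercised end to end. The sketch below is editorial (not part of the commit): it assumes the extra is installed via `pip install "distilabel[ollama]"`, that a local Ollama server is running with the illustrative model "llama2" already pulled, and that the base `LLM.generate(inputs)` interface applies.

# Editorial sketch, not part of the commit. Assumes `pip install "distilabel[ollama]"`,
# a running Ollama server, and that the illustrative model "llama2" has been pulled.
from distilabel.llm import OllamaLLM
from distilabel.tasks.text_generation.base import TextGenerationTask

llm = OllamaLLM(
    model="llama2",  # illustrative model name
    task=TextGenerationTask(),
    max_new_tokens=128,
    temperature=0.7,
)

# Assumes the base `LLM.generate(inputs)` interface; each input dict carries the
# keys expected by the task (a single "input" for TextGenerationTask).
outputs = llm.generate([{"input": "What does the ollama extra add to distilabel?"}])
print(outputs)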
1 change: 1 addition & 0 deletions docs/index.md
@@ -23,6 +23,7 @@ In addition, the following extras are available:
- `openai`: for using OpenAI API models via the `OpenAILLM` integration.
- `vllm`: for using [vllm](https://github.com/vllm-project/vllm) serving engine via the `vLLM` integration.
- `llama-cpp`: for using [llama-cpp-python](https://github.com/abetlen/llama-cpp-python) as Python bindings for `llama.cpp`.
- `ollama`: for using [Ollama](https://github.com/ollama/ollama) and their available models via their Python client.
- `together`: for using [Together Inference](https://www.together.ai/products) via their Python client.
- `vertexai`: for using both [Google Vertex AI](https://cloud.google.com/vertex-ai/?&gad_source=1&hl=es) offerings: their proprietary models and endpoints via their Python client [`google-cloud-aiplatform`](https://github.com/googleapis/python-aiplatform).
- `argilla`: for exporting the generated datasets to [Argilla](https://argilla.io/).
1 change: 1 addition & 0 deletions pyproject.toml
@@ -35,6 +35,7 @@ dev = ["black == 23.10.0", "ruff == 0.1.0", "pre-commit >= 3.5.0"]
hf-transformers = ["transformers >= 4.34.1", "torch >= 2.0.0"]
hf-inference-endpoints = ["huggingface_hub >= 0.19.0"]
llama-cpp = ["llama-cpp-python >= 0.2.0"]
ollama = ["ollama >= 0.1.4"]
openai = ["openai >= 1.0.0"]
vllm = ["vllm >= 0.2.1"]
vertexai = ["google-cloud-aiplatform >= 1.38.0"]
156 changes: 100 additions & 56 deletions src/distilabel/llm/ollama.py
@@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import os
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Union
from urllib import error, request
from urllib import error

from tenacity import (
after_log,
@@ -30,6 +29,10 @@
from distilabel.llm.base import LLM
from distilabel.llm.utils import LLMOutput
from distilabel.logger import get_logger
from distilabel.utils.imports import _OLLAMA_AVAILABLE

if _OLLAMA_AVAILABLE:
import ollama

if TYPE_CHECKING:
from distilabel.tasks.base import Task
@@ -54,7 +57,18 @@ def __init__(
temperature: Union[float, None] = None,
top_k: Union[int, None] = None,
top_p: Union[float, None] = None,
mirostat: Union[int, None] = None,
mirostat_eta: Union[float, None] = None,
mirostat_tau: Union[float, None] = None,
num_ctx: Union[int, None] = None,
num_gqa: Union[int, None] = None,
num_gpu: Union[int, None] = None,
num_threads: Union[int, None] = None,
repeat_last_n: Union[int, None] = None,
repeat_penalty: Union[float, None] = None,
seed: Union[int, None] = None,
stop: Union[str, None] = None,
tfs_z: Union[float, None] = None,
prompt_format: Union["SupportedFormats", None] = None,
prompt_formatting_fn: Union[Callable[..., str], None] = None,
) -> None:
@@ -72,9 +86,31 @@ def __init__(
Defaults to `None`.
top_p (float, optional): the top-p value to be used for generation.
Defaults to `None`.
mirostat (int, optional): whether to enable Mirostat sampling and which version to
use (0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0). Defaults to `None`.
mirostat_eta (float, optional): the eta value to be used for Mirostat.
Defaults to `None`.
mirostat_tau (float, optional): the tau value to be used for Mirostat.
Defaults to `None`.
num_ctx (int, optional): the size of the context window to be used for generation.
Defaults to `None`.
num_gqa (int, optional): the number of GQA (grouped-query attention) groups in the
transformer layer. Defaults to `None`.
num_gpu (int, optional): the number of model layers to offload to the GPU(s).
Defaults to `None`.
num_threads (Union[int, None], optional): the number of threads to be used
for parallel generation. If `None`, no parallel generation will be performed.
Defaults to `None`.
repeat_last_n (Union[int, None], optional): how far back (in tokens) the model looks
to penalize repetition. Defaults to `None`.
repeat_penalty (Union[float, None], optional): the penalty applied to repeated tokens.
Defaults to `None`.
seed (Union[int, None], optional): the seed to be used for generation.
Defaults to `None`.
stop (Union[str, None], optional): the stop token to be used for generation. If `None`,
no stop token will be used. Defaults to `None`.
tfs_z (Union[float, None], optional): the z value to be used for tail-free sampling (TFS).
Defaults to `None`.
prompt_format (Union[SupportedFormats, None], optional): the format to be used
for the prompt. If `None`, the default format of the task will be used, available
formats are `openai`, `chatml`, `llama2`, `zephyr`, and `default`. Defaults to `None`,
@@ -106,6 +142,17 @@ def __init__(
self.temperature = temperature
self.top_k = top_k
self.top_p = top_p
self.mirostat = mirostat
self.mirostat_eta = mirostat_eta
self.mirostat_tau = mirostat_tau
self.num_ctx = num_ctx
self.num_gqa = num_gqa
self.num_gpu = num_gpu
self.repeat_last_n = repeat_last_n
self.repeat_penalty = repeat_penalty
self.seed = seed
self.stop = stop
self.tfs_z = tfs_z

self._api_available()
self._api_model_available()
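For reference, the newly exposed Ollama sampling options are plain constructor arguments. The editorial sketch below mirrors the option values used in the updated test fixture further down; the model name is an illustrative assumption and a reachable Ollama server is taken for granted.

# Editorial sketch mirroring the option values from tests/llm/test_ollama.py.
# The model name is illustrative and a reachable Ollama server is assumed.
from distilabel.llm import OllamaLLM
from distilabel.tasks.text_generation.base import TextGenerationTask

llm = OllamaLLM(
    model="llama2",      # illustrative model name
    task=TextGenerationTask(),
    temperature=0.8,
    top_k=5,
    top_p=0.9,
    mirostat=0,          # 0 disables Mirostat sampling
    mirostat_eta=0.1,
    mirostat_tau=5.0,
    num_ctx=2048,        # context window size
    num_threads=4,
    repeat_last_n=64,
    repeat_penalty=1.1,
    seed=0,
    tfs_z=1,             # 1.0 effectively disables tail-free sampling
)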
@@ -116,23 +163,22 @@ def model_name(self) -> str:
return self.model

def _api_available(self):
"""Calls GET {OLLAMA_HOST}"""
msg = f"Could not connect to Ollama as {self.OLLAMA_HOST}. Check https://github.com/ollama/ollama for deployment guide."
"""Checks if the Ollama API is available."""
try:
response = request.urlopen(self.OLLAMA_HOST)
if response.getcode() != 200:
raise Exception
except Exception as e:
raise ValueError(msg) from e
ollama.list()
except ollama.ResponseError as e:
raise ValueError(
f"Could not connect to Ollama at {self.OLLAMA_HOST}. Check https://github.com/ollama/ollama-python/tree/main for deployment guide."
) from e

def _api_model_available(self):
msg = f"Model {self.model} is not available. Run `ollama run {self.model}` to serve the model."
"""Checks if the Ollama model is available"""
try:
self._text_generation_with_backoff(
prompt=[{"role": "user", "content": "hi"}], max_tokens=1
)
except Exception as e:
raise ValueError(msg) from e
ollama.show(self.model)
except ollama.ResponseError as e:
raise ValueError(
f"Model {self.model} is not available. Run `ollama run {self.model}` to serve the model."
) from e
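Outside of distilabel, the same two health checks can be reproduced directly with the official client. This is an editorial sketch with an illustrative model name; note that, as the updated test below expects, a connection failure surfaces as `httpx.ConnectError` rather than `ollama.ResponseError`.

# Editorial sketch of the two probes performed above, using the ollama client directly.
import httpx
import ollama

def check_server_and_model(model: str = "llama2") -> None:  # illustrative model name
    try:
        ollama.list()  # same probe as _api_available
    except (httpx.ConnectError, ollama.ResponseError) as e:
        raise RuntimeError("Could not connect to the Ollama server") from e
    try:
        ollama.show(model)  # same probe as _api_model_available
    except ollama.ResponseError as e:
        raise RuntimeError(
            f"Model {model!r} is not available; run `ollama run {model}` to serve it"
        ) from e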

@retry(
retry=retry_if_exception_type(_OLLAMA_API_RETRY_ON_EXCEPTIONS),
@@ -147,49 +193,36 @@ def _api_model_available(self):
def _text_generation_with_backoff(
self, prompt: List[Dict[str, str]], **kwargs
) -> str:
"""Calls POST {OLLAMA_HOST}/api/chat"""
# Request payload
payload = {
"model": self.model,
"messages": prompt,
"stream": False,
}
options = {
"num_predict": kwargs.get("max_new_tokens") or self.max_new_tokens,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
}
# remove None values
options = {k: v for k, v in options.items() if v is not None}
payload["options"] = options

# Convert payload to JSON
data = json.dumps(payload).encode("utf-8")

# Create the request
url = f"{self.OLLAMA_HOST}/api/chat"
req = request.Request(
url, data=data, headers={"Content-Type": "application/json"}
)
with request.urlopen(req) as response:
# Check if the request was successful (status code 200)
if response.getcode() == 200:
# Parse and return the response JSON
return json.loads(response.read().decode("utf-8"))
elif response.getcode() >= 500:
# If the request failed, try again with backoff
raise error.HTTPError(
url=url,
code=response.getcode(),
msg=f"Server Error {response.getcode()}",
hdrs=response.getheaders(),
fp=None,
)
"""Generates text using the Ollama API with backoff."""
try:
return ollama.chat(
model=self.model,
messages=prompt,
options={
"num_predict": self.max_new_tokens,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"mirostat": self.mirostat,
"mirostat_eta": self.mirostat_eta,
"mirostat_tau": self.mirostat_tau,
"num_ctx": self.num_ctx,
"num_gqa": self.num_gqa,
"num_gpu": self.num_gpu,
"repeat_last_n": self.repeat_last_n,
"repeat_penalty": self.repeat_penalty,
"seed": self.seed,
"stop": self.stop,
"tfs_z": self.tfs_z,
},
)
except ollama.ResponseError as e:
if e.status_code >= 500:
raise
else:
raise ValueError(
f"Ollama API request failed with status_code {response.getcode()}."
)
f"Ollama API request failed with status_code {e.status_code}."
) from e

def __rich_repr__(self) -> Generator[Any, None, None]:
yield from super().__rich_repr__()
@@ -201,6 +234,17 @@ def __rich_repr__(self) -> Generator[Any, None, None]:
"temperature": self.temperature,
"top_k": self.top_k,
"top_p": self.top_p,
"mirostat": self.mirostat,
"mirostat_eta": self.mirostat_eta,
"mirostat_tau": self.mirostat_tau,
"num_ctx": self.num_ctx,
"num_gqa": self.num_gqa,
"num_gpu": self.num_gpu,
"repeat_last_n": self.repeat_last_n,
"repeat_penalty": self.repeat_penalty,
"seed": self.seed,
"stop": self.stop,
"tfs_z": self.tfs_z,
},
)

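The rewritten `_text_generation_with_backoff` above boils down to a single `ollama.chat` call. Below is an editorial sketch of that call in isolation, outside the retry/backoff wrapper; the model name and option values are illustrative assumptions.

# Editorial sketch of the underlying ollama.chat call, outside the retry/backoff wrapper.
import ollama

response = ollama.chat(
    model="llama2",  # illustrative model name
    messages=[{"role": "user", "content": "hi"}],
    options={
        "num_predict": 64,   # maps from max_new_tokens
        "temperature": 0.8,
        "top_k": 5,
        "top_p": 0.9,
    },
)
# The generated text is returned under the "message" payload.
print(response["message"]["content"])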
5 changes: 4 additions & 1 deletion src/distilabel/utils/imports.py
@@ -91,7 +91,7 @@ def _check_package_is_available(


_ARGILLA_AVAILABLE = (
_check_package_is_available("argilla", min_version="1.22.0", greater_or_equal=True)
_check_package_is_available("argilla", min_version="1.23.0", greater_or_equal=True)
and _check_package_is_available(
"sentence-transformers", min_version="2.0.0", greater_or_equal=True
)
@@ -105,6 +105,9 @@ def _check_package_is_available(
_LLAMA_CPP_AVAILABLE = _check_package_is_available(
"llama_cpp_python", min_version="0.2.0", greater_or_equal=True
)
_OLLAMA_AVAILABLE = _check_package_is_available(
"ollama", min_version="0.1.4", greater_or_equal=True
)
_VLLM_AVAILABLE = _check_package_is_available(
"vllm", min_version="0.2.1", greater_or_equal=True
)
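The new `_OLLAMA_AVAILABLE` flag reuses the same version-gated import check as the other integrations. Below is a rough editorial sketch of that idea, not the actual `_check_package_is_available` implementation; the helper name is illustrative.

# Rough editorial sketch of the availability check; not the actual distilabel implementation.
import importlib.metadata

from packaging.version import Version

def package_is_available(name: str, min_version: str) -> bool:  # illustrative helper name
    try:
        installed = importlib.metadata.version(name)
    except importlib.metadata.PackageNotFoundError:
        return False
    return Version(installed) >= Version(min_version)

_OLLAMA_AVAILABLE = package_is_available("ollama", min_version="0.1.4")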
30 changes: 28 additions & 2 deletions tests/llm/test_ollama.py
@@ -14,6 +14,7 @@

from unittest.mock import Mock

import httpx
import pytest
from distilabel.llm import OllamaLLM
from distilabel.tasks.text_generation.base import TextGenerationTask
@@ -22,13 +23,13 @@
@pytest.fixture(scope="module")
def mock_ollama_llm():
task = TextGenerationTask()
with pytest.raises(ValueError):
with pytest.raises(httpx.ConnectError):
OllamaLLM(
model="test_model",
task=task,
)
OllamaLLM._api_available = Mock(return_value=None)
with pytest.raises(ValueError):
with pytest.raises(httpx.ConnectError):
OllamaLLM(
model="test_model",
task=task,
@@ -55,7 +56,18 @@ def mock_ollama_llm():
temperature=0.8,
top_k=5,
top_p=0.9,
mirostat=0,
mirostat_eta=0.1,
mirostat_tau=5.0,
num_ctx=2048,
num_gqa=None,
num_gpu=None,
num_threads=4,
repeat_last_n=64,
repeat_penalty=1.1,
seed=0,
stop=None,
tfs_z=1,
prompt_format="default",
prompt_formatting_fn=None,
)
@@ -67,6 +79,20 @@ def test_ollama_llm_init(mock_ollama_llm: OllamaLLM):
assert mock_ollama_llm.temperature == 0.8
assert mock_ollama_llm.top_k == 5
assert mock_ollama_llm.top_p == 0.9
assert mock_ollama_llm.mirostat == 0
assert mock_ollama_llm.mirostat_eta == 0.1
assert mock_ollama_llm.mirostat_tau == 5.0
assert mock_ollama_llm.num_ctx == 2048
assert mock_ollama_llm.num_gqa is None
assert mock_ollama_llm.num_gpu is None
assert mock_ollama_llm.num_threads == 4
assert mock_ollama_llm.repeat_last_n == 64
assert mock_ollama_llm.repeat_penalty == 1.1
assert mock_ollama_llm.seed == 0
assert mock_ollama_llm.stop is None
assert mock_ollama_llm.tfs_z == 1
assert mock_ollama_llm.prompt_format == "default"
assert mock_ollama_llm.prompt_formatting_fn is None


def test_ollama_llm_inherits_from_task(mock_ollama_llm):
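Because the constructor now calls the `ollama` client directly, tests can also stub the client functions themselves instead of replacing the private check methods. The following editorial sketch illustrates the idea; the test name and return values are illustrative assumptions.

# Editorial sketch: stubbing the ollama client calls made during construction.
from unittest.mock import patch

from distilabel.llm import OllamaLLM
from distilabel.tasks.text_generation.base import TextGenerationTask

def test_ollama_llm_with_stubbed_client() -> None:  # illustrative test name
    with patch("ollama.list", return_value={"models": []}), patch(
        "ollama.show", return_value={}
    ):
        llm = OllamaLLM(model="test_model", task=TextGenerationTask())
        assert llm.model_name == "test_model"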
