diff --git a/docs/sections/learn/llms/index.md b/docs/sections/learn/llms/index.md
index cec17806e..fb890a90b 100644
--- a/docs/sections/learn/llms/index.md
+++ b/docs/sections/learn/llms/index.md
@@ -98,6 +98,8 @@ Once those methods have been implemented, then the custom LLM will be ready to b
 ```python
 from typing import Any

+from pydantic import validate_call
+
 from distilabel.llms import AsyncLLM, LLM
 from distilabel.llms.typing import GenerateOutput, HiddenState
 from distilabel.steps.tasks.typing import ChatType
@@ -107,7 +109,8 @@ class CustomLLM(LLM):
     def model_name(self) -> str:
         return "my-model"

-    def generate(self, inputs: List[ChatType], num_generations: int = 1) -> List[GenerateOutput]:
+    @validate_call
+    def generate(self, inputs: List[ChatType], num_generations: int = 1, **kwargs: Any) -> List[GenerateOutput]:
         for _ in range(num_generations):
             ...

@@ -120,6 +123,7 @@ class CustomAsyncLLM(AsyncLLM):
     def model_name(self) -> str:
         return "my-model"

+    @validate_call
     async def agenerate(self, input: ChatType, num_generations: int = 1, **kwargs: Any) -> GenerateOutput:
         for _ in range(num_generations):
             ...
@@ -128,6 +132,11 @@ class CustomAsyncLLM(AsyncLLM):
     ...
 ```

+The keyword arguments of `generate` and `agenerate` (except for `input`/`inputs` and `num_generations`) are considered `RuntimeParameter`s, so a value can be passed to them via the `parameters` argument of the `Pipeline.run` method.
+
+!!! NOTE
+    The `validate_call` decorator is applied to `generate` and `agenerate` so that their arguments are automatically coerced to the annotated types, raising an error if a value cannot be coerced. This is especially useful when providing a value for an argument of `generate` or `agenerate` via the CLI, since the CLI always provides the arguments as strings.
+
 ## Available LLMs

 Here's a list with the available LLMs that can be used within the `distilabel` library:
diff --git a/src/distilabel/llms/anthropic.py b/src/distilabel/llms/anthropic.py
index a863bd7a9..f472aca66 100644
--- a/src/distilabel/llms/anthropic.py
+++ b/src/distilabel/llms/anthropic.py
@@ -27,18 +27,18 @@
 )

 from httpx import AsyncClient
-from pydantic import Field, PrivateAttr, SecretStr
+from pydantic import Field, PrivateAttr, SecretStr, validate_call
 from typing_extensions import override

 from distilabel.llms.base import AsyncLLM
+from distilabel.llms.typing import GenerateOutput
 from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.steps.tasks.typing import ChatType
 from distilabel.utils.itertools import grouper

 if TYPE_CHECKING:
     from anthropic import AsyncAnthropic

-    from distilabel.llms.typing import GenerateOutput
-    from distilabel.steps.tasks.typing import ChatType


 _ANTHROPIC_API_KEY_ENV_VAR_NAME = "ANTHROPIC_API_KEY"
@@ -149,15 +149,16 @@ def model_name(self) -> str:
         """Returns the model name used for the LLM."""
         return self.model

+    @validate_call
     async def agenerate(  # type: ignore
         self,
-        input: "ChatType",
+        input: ChatType,
         max_tokens: int = 128,
         stop_sequences: Union[List[str], None] = None,
         temperature: float = 1.0,
         top_p: Union[float, None] = None,
         top_k: Union[int, None] = None,
-    ) -> "GenerateOutput":
+    ) -> GenerateOutput:
         """Generates a response asynchronously, using the [Anthropic Async API definition](https://github.com/anthropics/anthropic-sdk-python).
         Args:
@@ -173,14 +174,14 @@ async def agenerate(  # type: ignore
         """
         from anthropic._types import NOT_GIVEN

-        completion = await self._aclient.messages.create(
+        completion = await self._aclient.messages.create(  # type: ignore
             model=self.model,
             system=(
                 input.pop(0)["content"]
                 if input and input[0]["role"] == "system"
                 else NOT_GIVEN
             ),
-            messages=input,
+            messages=input,  # type: ignore
             max_tokens=max_tokens,
             stream=False,
             stop_sequences=NOT_GIVEN if stop_sequences is None else stop_sequences,
diff --git a/src/distilabel/llms/cohere.py b/src/distilabel/llms/cohere.py
index cde3dd578..d2774b591 100644
--- a/src/distilabel/llms/cohere.py
+++ b/src/distilabel/llms/cohere.py
@@ -25,18 +25,18 @@
     Union,
 )

-from pydantic import Field, PrivateAttr, SecretStr
+from pydantic import Field, PrivateAttr, SecretStr, validate_call
 from typing_extensions import override

 from distilabel.llms.base import AsyncLLM
 from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.steps.tasks.typing import ChatType
 from distilabel.utils.itertools import grouper

 if TYPE_CHECKING:
     from cohere import AsyncClient, ChatMessage

     from distilabel.llms.typing import GenerateOutput
-    from distilabel.steps.tasks.typing import ChatType


 _COHERE_API_KEY_ENV_VAR_NAME = "COHERE_API_KEY"
@@ -153,10 +153,10 @@ def _format_chat_to_cohere(

         return system, chat_history, message

-    @override
+    @validate_call
     async def agenerate(  # type: ignore
         self,
-        input: "ChatType",
+        input: ChatType,
         temperature: Optional[float] = None,
         max_tokens: Optional[int] = None,
         k: Optional[int] = None,
diff --git a/src/distilabel/llms/huggingface/inference_endpoints.py b/src/distilabel/llms/huggingface/inference_endpoints.py
index 3c4fef1db..d282b99a5 100644
--- a/src/distilabel/llms/huggingface/inference_endpoints.py
+++ b/src/distilabel/llms/huggingface/inference_endpoints.py
@@ -16,11 +16,20 @@
 import os
 from typing import TYPE_CHECKING, Any, List, Optional, Union

-from pydantic import Field, PrivateAttr, SecretStr, ValidationError, model_validator
+from pydantic import (
+    Field,
+    PrivateAttr,
+    SecretStr,
+    ValidationError,
+    model_validator,
+    validate_call,
+)
 from typing_extensions import override

 from distilabel.llms.base import AsyncLLM
+from distilabel.llms.typing import GenerateOutput
 from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.steps.tasks.typing import ChatType
 from distilabel.utils.itertools import grouper

 if TYPE_CHECKING:
@@ -28,8 +37,6 @@
     from openai import AsyncOpenAI
     from transformers import PreTrainedTokenizer

-    from distilabel.llms.typing import GenerateOutput
-    from distilabel.steps.tasks.typing import ChatType


 _INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME = "HF_TOKEN"
@@ -220,7 +227,7 @@ async def _openai_agenerate(
         presence_penalty: float = 0.0,
         temperature: float = 1.0,
         top_p: Optional[float] = None,
-    ) -> "GenerateOutput":
+    ) -> GenerateOutput:
         """Generates completions for the given input using the OpenAI async client."""
         completion = await self._aclient.chat.completions.create(  # type: ignore
             messages=input,  # type: ignore
@@ -241,9 +248,10 @@ async def _openai_agenerate(
         return [completion.choices[0].message.content]

     # TODO: add `num_generations` parameter once either TGI or `AsyncInferenceClient` allows `n` parameter
+    @validate_call
     async def agenerate(  # type: ignore
         self,
-        input: "ChatType",
+        input: ChatType,
         max_new_tokens: int = 128,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
diff --git a/src/distilabel/llms/huggingface/transformers.py b/src/distilabel/llms/huggingface/transformers.py
index 697c67722..83263cb86 100644
--- a/src/distilabel/llms/huggingface/transformers.py
+++ b/src/distilabel/llms/huggingface/transformers.py
@@ -15,19 +15,20 @@
 import os
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

-from pydantic import PrivateAttr
+from pydantic import PrivateAttr, validate_call

 from distilabel.llms.base import LLM
 from distilabel.llms.chat_templates import CHATML_TEMPLATE
 from distilabel.llms.mixins import CudaDevicePlacementMixin
+from distilabel.llms.typing import GenerateOutput
+from distilabel.steps.tasks.typing import ChatType

 if TYPE_CHECKING:
     from transformers import Pipeline
     from transformers.modeling_utils import PreTrainedModel
     from transformers.tokenization_utils import PreTrainedTokenizer

-    from distilabel.llms.typing import GenerateOutput, HiddenState
-    from distilabel.steps.tasks.typing import ChatType
+    from distilabel.llms.typing import HiddenState


 class TransformersLLM(LLM, CudaDevicePlacementMixin):
@@ -129,9 +130,10 @@ def prepare_input(self, input: "ChatType") -> str:
             add_generation_prompt=True,
         )

+    @validate_call
     def generate(  # type: ignore
         self,
-        inputs: List["ChatType"],
+        inputs: List[ChatType],
         num_generations: int = 1,
         max_new_tokens: int = 128,
         temperature: float = 0.1,
@@ -139,7 +141,7 @@ def generate(  # type: ignore
         top_p: float = 1.0,
         top_k: int = 0,
         do_sample: bool = True,
-    ) -> List["GenerateOutput"]:
+    ) -> List[GenerateOutput]:
         """Generates `num_generations` responses for each input using the text generation
         pipeline.
diff --git a/src/distilabel/llms/litellm.py b/src/distilabel/llms/litellm.py
index 5ce9e09bd..0fff0c3ea 100644
--- a/src/distilabel/llms/litellm.py
+++ b/src/distilabel/llms/litellm.py
@@ -15,17 +15,16 @@
 import logging
 from typing import TYPE_CHECKING, Callable, List, Optional, Union

-from pydantic import Field, PrivateAttr
+from pydantic import Field, PrivateAttr, validate_call

 from distilabel.llms.base import AsyncLLM
+from distilabel.llms.typing import GenerateOutput
 from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.steps.tasks.typing import ChatType

 if TYPE_CHECKING:
     from litellm import Choices

-    from distilabel.llms.typing import GenerateOutput
-    from distilabel.steps.tasks.typing import ChatType
-

 class LiteLLM(AsyncLLM):
     """LiteLLM implementation running the async API client.
@@ -74,9 +73,10 @@ def model_name(self) -> str:
         """Returns the model name used for the LLM."""
         return self.model

+    @validate_call
     async def agenerate(  # type: ignore
         self,
-        input: "ChatType",
+        input: ChatType,
         num_generations: int = 1,
         functions: Optional[List] = None,
         function_call: Optional[str] = None,
@@ -96,7 +96,7 @@ async def agenerate(  # type: ignore
         mock_response: Optional[str] = None,
         force_timeout: Optional[int] = 600,
         custom_llm_provider: Optional[str] = None,
-    ) -> "GenerateOutput":
+    ) -> GenerateOutput:
         """Generates `num_generations` responses for the given input using the [LiteLLM async client](https://github.com/BerriAI/litellm).
         Args:
diff --git a/src/distilabel/llms/llamacpp.py b/src/distilabel/llms/llamacpp.py
index f74a5c9e3..36c163b43 100644
--- a/src/distilabel/llms/llamacpp.py
+++ b/src/distilabel/llms/llamacpp.py
@@ -14,17 +14,16 @@
 from typing import TYPE_CHECKING, List, Optional

-from pydantic import Field, FilePath, PrivateAttr
+from pydantic import Field, FilePath, PrivateAttr, validate_call

 from distilabel.llms.base import LLM
+from distilabel.llms.typing import GenerateOutput
 from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.steps.tasks.typing import ChatType

 if TYPE_CHECKING:
     from llama_cpp import CreateChatCompletionResponse, Llama

-    from distilabel.llms.typing import GenerateOutput
-    from distilabel.steps.tasks.typing import ChatType
-

 class LlamaCppLLM(LLM):
     """llama.cpp LLM implementation running the Python bindings for the C++ code.
@@ -84,16 +83,17 @@ def model_name(self) -> str:
         """Returns the model name used for the LLM."""
         return self._model.model_path  # type: ignore

+    @validate_call
     def generate(  # type: ignore
         self,
-        inputs: List["ChatType"],
+        inputs: List[ChatType],
         num_generations: int = 1,
         max_new_tokens: int = 128,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         temperature: float = 1.0,
         top_p: float = 1.0,
-    ) -> List["GenerateOutput"]:
+    ) -> List[GenerateOutput]:
         """Generates `num_generations` responses for the given input using the Llama model.

         Args:
diff --git a/src/distilabel/llms/mistral.py b/src/distilabel/llms/mistral.py
index 46d56f2a8..d05d9d3f6 100644
--- a/src/distilabel/llms/mistral.py
+++ b/src/distilabel/llms/mistral.py
@@ -16,18 +16,18 @@
 import os
 from typing import TYPE_CHECKING, Any, List, Optional

-from pydantic import Field, PrivateAttr, SecretStr
+from pydantic import Field, PrivateAttr, SecretStr, validate_call
 from typing_extensions import override

 from distilabel.llms.base import AsyncLLM
+from distilabel.llms.typing import GenerateOutput
 from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.steps.tasks.typing import ChatType
 from distilabel.utils.itertools import grouper

 if TYPE_CHECKING:
     from mistralai.async_client import MistralAsyncClient

-    from distilabel.llms.typing import GenerateOutput
-    from distilabel.steps.tasks.typing import ChatType


 _MISTRALAI_API_KEY_ENV_VAR_NAME = "MISTRAL_API_KEY"
@@ -113,13 +113,14 @@ def model_name(self) -> str:
         return self.model

     # TODO: add `num_generations` parameter once Mistral client allows `n` parameter
+    @validate_call
     async def agenerate(  # type: ignore
         self,
-        input: "ChatType",
+        input: ChatType,
         max_new_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
-    ) -> "GenerateOutput":
+    ) -> GenerateOutput:
         """Generates `num_generations` responses for the given input using the MistralAI async
         client.
diff --git a/src/distilabel/llms/ollama.py b/src/distilabel/llms/ollama.py
index 410176804..fb06f1eed 100644
--- a/src/distilabel/llms/ollama.py
+++ b/src/distilabel/llms/ollama.py
@@ -12,17 +12,57 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, List, Literal, Optional, Union
+from typing import TYPE_CHECKING, List, Literal, Optional, Sequence, Union

-from pydantic import Field, PrivateAttr
+from pydantic import Field, PrivateAttr, validate_call
+from typing_extensions import TypedDict

 from distilabel.llms.base import AsyncLLM
 from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.steps.tasks.typing import ChatType

 if TYPE_CHECKING:
-    from ollama import AsyncClient, Options
-
-    from distilabel.steps.tasks.typing import ChatType
+    from ollama import AsyncClient
+
+
+# Copied from `ollama._types.Options`
+class Options(TypedDict, total=False):
+    # load time options
+    numa: bool
+    num_ctx: int
+    num_batch: int
+    num_gqa: int
+    num_gpu: int
+    main_gpu: int
+    low_vram: bool
+    f16_kv: bool
+    logits_all: bool
+    vocab_only: bool
+    use_mmap: bool
+    use_mlock: bool
+    embedding_only: bool
+    rope_frequency_base: float
+    rope_frequency_scale: float
+    num_thread: int
+
+    # runtime options
+    num_keep: int
+    seed: int
+    num_predict: int
+    top_k: int
+    top_p: float
+    tfs_z: float
+    typical_p: float
+    repeat_last_n: int
+    temperature: float
+    repeat_penalty: float
+    presence_penalty: float
+    frequency_penalty: float
+    mirostat: int
+    mirostat_tau: float
+    mirostat_eta: float
+    penalize_newline: bool
+    stop: Sequence[str]


 class OllamaLLM(AsyncLLM):
@@ -74,12 +114,14 @@ def model_name(self) -> str:
         """Returns the model name used for the LLM."""
         return self.model

+    @validate_call
     async def agenerate(  # type: ignore
         self,
-        input: "ChatType",
+        input: ChatType,
         num_generations: int = 1,
         format: Literal["", "json"] = "",
-        options: Union["Options", None] = None,
+        # TODO: include relevant options from `Options` in `agenerate` method.
+        options: Union[Options, None] = None,
         keep_alive: Union[bool, None] = None,
     ) -> List[str]:
         """
diff --git a/src/distilabel/llms/openai.py b/src/distilabel/llms/openai.py
index 72d7b338b..87a0525b8 100644
--- a/src/distilabel/llms/openai.py
+++ b/src/distilabel/llms/openai.py
@@ -15,16 +15,16 @@
 import os
 from typing import TYPE_CHECKING, Optional

-from pydantic import Field, PrivateAttr, SecretStr
+from pydantic import Field, PrivateAttr, SecretStr, validate_call

 from distilabel.llms.base import AsyncLLM
+from distilabel.llms.typing import GenerateOutput
 from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.steps.tasks.typing import ChatType

 if TYPE_CHECKING:
     from openai import AsyncOpenAI

-    from distilabel.llms.typing import GenerateOutput
-    from distilabel.steps.tasks.typing import ChatType


 _OPENAI_API_KEY_ENV_VAR_NAME = "OPENAI_API_KEY"
@@ -110,16 +110,17 @@ def model_name(self) -> str:
         """Returns the model name used for the LLM."""
         return self.model

+    @validate_call
     async def agenerate(  # type: ignore
         self,
-        input: "ChatType",
+        input: ChatType,
         num_generations: int = 1,
         max_new_tokens: int = 128,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         temperature: float = 1.0,
         top_p: float = 1.0,
-    ) -> "GenerateOutput":
+    ) -> GenerateOutput:
         """Generates `num_generations` responses for the given input using the OpenAI async
         client.
diff --git a/src/distilabel/llms/vertexai.py b/src/distilabel/llms/vertexai.py
index 1dd7fa654..364769b5d 100644
--- a/src/distilabel/llms/vertexai.py
+++ b/src/distilabel/llms/vertexai.py
@@ -12,22 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional

-from pydantic import PrivateAttr
+from pydantic import PrivateAttr, validate_call

 from distilabel.llms.base import AsyncLLM
+from distilabel.llms.typing import GenerateOutput
+from distilabel.steps.tasks.typing import ChatType

 if TYPE_CHECKING:
-    from vertexai.generative_models import (
-        Content,
-        GenerativeModel,
-        SafetySettingsType,
-        Tool,
-    )
-
-    from distilabel.llms.typing import GenerateOutput
-    from distilabel.steps.tasks.typing import ChatType
+    from vertexai.generative_models import Content, GenerativeModel


 def _is_gemini_model(model: str) -> bool:
@@ -113,18 +107,19 @@ def _chattype_to_content(self, input: "ChatType") -> List["Content"]:
         )
         return contents

+    @validate_call
     async def agenerate(  # type: ignore
         self,
-        input: "ChatType",
+        input: ChatType,
         num_generations: int = 1,
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
         top_k: Optional[int] = None,
         max_output_tokens: Optional[int] = None,
         stop_sequences: Optional[List[str]] = None,
-        safety_settings: Optional["SafetySettingsType"] = None,
-        tools: Optional[List["Tool"]] = None,
-    ) -> "GenerateOutput":
+        safety_settings: Optional[Dict[str, Any]] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
+    ) -> GenerateOutput:
         """Generates `num_generations` responses for the given input using the [VertexAI async client definition](https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini).

         Args:
diff --git a/src/distilabel/llms/vllm.py b/src/distilabel/llms/vllm.py
index e25d76278..be04a17ec 100644
--- a/src/distilabel/llms/vllm.py
+++ b/src/distilabel/llms/vllm.py
@@ -14,19 +14,19 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Optional

-from pydantic import Field, PrivateAttr
+from pydantic import Field, PrivateAttr, validate_call

 from distilabel.llms.base import LLM
 from distilabel.llms.chat_templates import CHATML_TEMPLATE
 from distilabel.llms.mixins import CudaDevicePlacementMixin
+from distilabel.llms.typing import GenerateOutput
 from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.steps.tasks.typing import ChatType

 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
     from vllm import LLM as _vLLM

-    from distilabel.llms.typing import GenerateOutput
-    from distilabel.steps.tasks.typing import ChatType


 SamplingParams = None
@@ -112,9 +112,10 @@ def prepare_input(self, input: "ChatType") -> str:
             add_generation_prompt=True,  # type: ignore
         )

-    def generate(
+    @validate_call
+    def generate(  # type: ignore
         self,
-        inputs: List["ChatType"],
+        inputs: List[ChatType],
         num_generations: int = 1,
         max_new_tokens: int = 128,
         frequency_penalty: float = 0.0,
@@ -123,7 +124,7 @@ def generate(
         top_p: float = 1.0,
         top_k: int = -1,
         extra_sampling_params: Optional[Dict[str, Any]] = None,
-    ) -> List["GenerateOutput"]:
+    ) -> List[GenerateOutput]:
         """Generates `num_generations` responses for each input using the text generation
         pipeline.
diff --git a/src/distilabel/steps/tasks/typing.py b/src/distilabel/steps/tasks/typing.py
index 8304674f8..cbd6ffc09 100644
--- a/src/distilabel/steps/tasks/typing.py
+++ b/src/distilabel/steps/tasks/typing.py
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import List, TypedDict
+from typing import List
+
+from typing_extensions import TypedDict


 class ChatItem(TypedDict):
diff --git a/tests/unit/llms/test_ollama.py b/tests/unit/llms/test_ollama.py
index e8cc10b53..1c64f1ed5 100644
--- a/tests/unit/llms/test_ollama.py
+++ b/tests/unit/llms/test_ollama.py
@@ -62,13 +62,11 @@ async def test_generate(self, mock_ollama: MagicMock) -> None:
         llm.generate(
             inputs=[
                 [
-                    [
-                        {"role": "system", "content": ""},
-                        {
-                            "role": "user",
-                            "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
-                        },
-                    ]
+                    {"role": "system", "content": ""},
+                    {
+                        "role": "user",
+                        "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
+                    },
                 ]
             ]
         )
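As a quick illustration of the behaviour this change relies on, the following standalone sketch shows how pydantic's `validate_call` coerces string values (such as the ones the CLI passes for runtime parameters) into the annotated types before the function body runs. The `generate` function and its arguments below are made up for illustration and are not part of `distilabel`.

```python
from pydantic import ValidationError, validate_call


@validate_call
def generate(prompt: str, temperature: float = 1.0, max_new_tokens: int = 128) -> str:
    # Hypothetical stand-in for an `LLM.generate` implementation: by the time this
    # body runs, `temperature` and `max_new_tokens` have been coerced to float/int.
    return f"temperature={temperature!r}, max_new_tokens={max_new_tokens!r}"


# String values (as the CLI would pass them) are coerced to the annotated types.
print(generate("Hello", temperature="0.7", max_new_tokens="64"))
# -> temperature=0.7, max_new_tokens=64

# Values that cannot be coerced raise a `ValidationError` instead of failing later.
try:
    generate("Hello", temperature="warm")
except ValidationError as exc:
    # `validate_call` surfaces the offending argument before any generation runs.
    print("rejected invalid value for:", exc.errors()[0]["loc"])
```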