
Commit 1f25794

chore: fix Hugging Face components for mypy 1.15.0 (#8822)

* chore: fix Hugging Face components for mypy 1.15.0
* small fixes
* fix test
* rm print
* use cast and be more permissive

anakin87 authored Feb 6, 2025
1 parent e7c6d14 commit 1f25794
Showing 4 changed files with 26 additions and 13 deletions.
25 changes: 18 additions & 7 deletions haystack/components/generators/chat/hugging_face_api.py
@@ -15,6 +15,7 @@

 with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.27.0\"'") as huggingface_hub_import:
     from huggingface_hub import (
+        ChatCompletionInputFunctionDefinition,
         ChatCompletionInputTool,
         ChatCompletionOutput,
         ChatCompletionStreamOutput,
@@ -255,8 +256,15 @@ def run(

         hf_tools = None
         if tools:
-            hf_tools = [{"type": "function", "function": {**t.tool_spec}} for t in tools]
-
+            hf_tools = [
+                ChatCompletionInputTool(
+                    function=ChatCompletionInputFunctionDefinition(
+                        name=tool.name, description=tool.description, arguments=tool.parameters
+                    ),
+                    type="function",
+                )
+                for tool in tools
+            ]
         return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools)

     def _run_streaming(
@@ -278,13 +286,12 @@ def _run_streaming(
             # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
             choice = chunk.choices[0]

-            text = choice.delta.content
-            if text:
-                generated_text += text
+            text = choice.delta.content or ""
+            generated_text += text

             finish_reason = choice.finish_reason

-            meta = {}
+            meta: Dict[str, Any] = {}
             if finish_reason:
                 meta["finish_reason"] = finish_reason

@@ -336,7 +343,11 @@ def _run_non_streaming(
                 )
                 tool_calls.append(tool_call)

-        meta = {"model": self._client.model, "finish_reason": choice.finish_reason, "index": choice.index}
+        meta: Dict[str, Any] = {
+            "model": self._client.model,
+            "finish_reason": choice.finish_reason,
+            "index": choice.index,
+        }

         usage = {"prompt_tokens": 0, "completion_tokens": 0}
         if api_chat_output.usage:
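
The substantive change in this file replaces the hand-built tool dicts with huggingface_hub's typed classes, which gives mypy 1.15.0 something it can actually check. A minimal sketch of one such tool object, built around a hypothetical weather tool (only the two huggingface_hub classes and their keyword arguments come from the diff above):

    from huggingface_hub import ChatCompletionInputFunctionDefinition, ChatCompletionInputTool

    # Hypothetical JSON-schema parameters for an example tool; not part of the commit.
    weather_parameters = {
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    }

    hf_tool = ChatCompletionInputTool(
        function=ChatCompletionInputFunctionDefinition(
            name="get_weather",
            description="Return the current weather for a city.",
            arguments=weather_parameters,
        ),
        type="function",
    )

The meta dictionaries in the same file gain explicit Dict[str, Any] annotations. A rough illustration of the mypy behaviour that likely motivates this (my reading, not stated in the commit): an unannotated empty dict gets its value type completed from the first assignment, so later entries of a different type are rejected.

    from typing import Any, Dict

    meta = {}
    meta["finish_reason"] = "stop"  # mypy completes the partial type as Dict[str, str]
    meta["index"] = 0               # flagged by mypy under that inference; fine at runtime

    meta_ok: Dict[str, Any] = {}    # explicit annotation keeps the value type open
    meta_ok["finish_reason"] = "stop"
    meta_ok["index"] = 0
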
2 changes: 1 addition & 1 deletion haystack/components/generators/chat/hugging_face_local.py
@@ -145,7 +145,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
         elif isinstance(huggingface_pipeline_kwargs["model"], str):
             task = model_info(
                 huggingface_pipeline_kwargs["model"], token=huggingface_pipeline_kwargs["token"]
-            ).pipeline_tag
+            ).pipeline_tag  # type: ignore[assignment]  # we'll check below if task is in supported tasks

         if task not in PIPELINE_SUPPORTED_TASKS:
             raise ValueError(
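
The one-line change above silences an assignment error: ModelInfo.pipeline_tag is optional in huggingface_hub, so assigning it to the task variable (presumably already inferred as str from an earlier branch) no longer type-checks under mypy 1.15.0. The commit keeps the assignment and relies on the PIPELINE_SUPPORTED_TASKS check that follows. An alternative that avoids the ignore is to narrow explicitly; a rough sketch with hypothetical names, not code from the commit:

    from typing import Optional

    # Stand-in for Haystack's PIPELINE_SUPPORTED_TASKS; the real values may differ.
    SUPPORTED_TASKS = {"text-generation", "text2text-generation"}

    def resolve_task(pipeline_tag: Optional[str]) -> str:
        # Raising on None or on an unknown tag narrows the type, so no ignore is needed.
        if pipeline_tag is None or pipeline_tag not in SUPPORTED_TASKS:
            raise ValueError(f"Unsupported or missing task: {pipeline_tag}")
        return pipeline_tag
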
9 changes: 5 additions & 4 deletions haystack/components/generators/hugging_face_api.py
@@ -4,7 +4,7 @@

 from dataclasses import asdict
 from datetime import datetime
-from typing import Any, Callable, Dict, Iterable, List, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Union, cast

 from haystack import component, default_from_dict, default_to_dict, logging
 from haystack.dataclasses import StreamingChunk
@@ -17,8 +17,8 @@
     from huggingface_hub import (
         InferenceClient,
         TextGenerationOutput,
-        TextGenerationOutputToken,
         TextGenerationStreamOutput,
+        TextGenerationStreamOutputToken,
     )


@@ -212,7 +212,8 @@ def run(
         if streaming_callback is not None:
             return self._stream_and_build_response(hf_output, streaming_callback)

-        return self._build_non_streaming_response(hf_output)
+        # mypy doesn't know that hf_output is a TextGenerationOutput, so we cast it
+        return self._build_non_streaming_response(cast(TextGenerationOutput, hf_output))

     def _stream_and_build_response(
         self, hf_output: Iterable["TextGenerationStreamOutput"], streaming_callback: Callable[[StreamingChunk], None]
@@ -221,7 +222,7 @@ def _stream_and_build_response(
         first_chunk_time = None

         for chunk in hf_output:
-            token: TextGenerationOutputToken = chunk.token
+            token: TextGenerationStreamOutputToken = chunk.token
             if token.special:
                 continue

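
In this generator the single text_generation call returns either one output or a stream of chunks depending on the streaming flag, so its static type is a union and mypy cannot tell which branch applies on the non-streaming path; the commit resolves that with typing.cast, which has no runtime effect. The import swap in the same file (TextGenerationStreamOutputToken instead of TextGenerationOutputToken) simply matches the annotation to the token class the streamed chunks actually carry. A generic sketch of the cast pattern, with illustrative names rather than the component's real signature:

    from typing import Iterable, Union, cast

    def fetch(stream: bool) -> Union[str, Iterable[str]]:
        # Stand-in for a call whose return type depends on a flag, like text_generation.
        return iter(["a", "b"]) if stream else "ab"

    def handle_non_streaming() -> str:
        out = fetch(stream=False)
        # cast() does nothing at runtime; it only tells mypy which Union member applies here.
        return cast(str, out)
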
3 changes: 2 additions & 1 deletion test/components/generators/test_hugging_face_api.py
@@ -7,6 +7,7 @@

 import pytest
 from huggingface_hub import (
+    TextGenerationOutput,
     TextGenerationOutputToken,
     TextGenerationStreamOutput,
     TextGenerationStreamOutputStreamDetails,
@@ -30,7 +31,7 @@ def mock_check_valid_model():
 @pytest.fixture
 def mock_text_generation():
     with patch("huggingface_hub.InferenceClient.text_generation", autospec=True) as mock_text_generation:
-        mock_response = Mock()
+        mock_response = Mock(spec=TextGenerationOutput)
         mock_response.generated_text = "I'm fine, thanks."
         details = Mock()
         details.finish_reason = MagicMock(field1="value")
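
The fixture's fake response is now built with spec=TextGenerationOutput, so it only exposes attributes the real class has and passes isinstance checks against it. A small standalone illustration of what spec= buys (the FakeOutput class is made up for the example):

    from unittest.mock import Mock

    class FakeOutput:
        generated_text = ""
        details = None

    strict = Mock(spec=FakeOutput)
    strict.generated_text = "I'm fine, thanks."

    print(isinstance(strict, FakeOutput))  # True: spec'd mocks pass isinstance checks
    try:
        strict.no_such_attribute  # not defined on FakeOutput, so this raises AttributeError
    except AttributeError as exc:
        print(exc)
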
