Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions haystack/components/generators/hugging_face_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,10 @@ def run(
replies = [o["generated_text"] for o in output if "generated_text" in o]

if self.stop_words:
# the output of the pipeline includes the stop word
replies = [reply.replace(stop_word, "").rstrip() for reply in replies for stop_word in self.stop_words]
# The output of the pipeline includes the stop word. Strip each stop word from each
# reply in sequence — the previous double-loop comprehension was a cross-product that
# produced N*M replies (half still containing a stop word) instead of N. See #11409.
for stop_word in self.stop_words:
replies = [reply.replace(stop_word, "").rstrip() for reply in replies]

return {"replies": replies}
28 changes: 28 additions & 0 deletions test/components/generators/test_hugging_face_local_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,34 @@ def test_run_stop_words_removal(self):
results = generator.run(prompt="irrelevant")
assert results == {"replies": ["Hello"]}

def test_run_stop_words_removal_with_multiple_stop_words(self):
"""Regression for #11409: with N replies and M stop words, the cross-product comprehension
produced N*M replies (half still containing a stop word). The result must stay at N replies,
each with every stop word removed."""
generator = HuggingFaceLocalGenerator(
model="Qwen/Qwen3-0.6B", task="text-generation", stop_words=["STOP", "END"]
)
generator.pipeline = Mock(
return_value=[
{"generated_text": "Paris is the capital. STOP"},
{"generated_text": "France is in Europe. END"},
]
)
generator.stopping_criteria_list = Mock()
results = generator.run(prompt="irrelevant")
assert results == {"replies": ["Paris is the capital.", "France is in Europe."]}

def test_run_stop_words_removal_all_stop_words_removed_from_each_reply(self):
"""Every stop word is removed from every reply, not just the first matching one."""
generator = HuggingFaceLocalGenerator(
model="Qwen/Qwen3-0.6B", task="text-generation", stop_words=["STOP", "END"]
)
# Reply contains BOTH stop words
generator.pipeline = Mock(return_value=[{"generated_text": "Hello STOP world END"}])
generator.stopping_criteria_list = Mock()
results = generator.run(prompt="irrelevant")
assert results == {"replies": ["Hello world"]}

@pytest.mark.integration
def test_stop_words_criteria_using_hf_tokenizer(self):
"""
Expand Down