
groq: add support for accessing reasoning output from Groq models #31662


Merged · 7 commits · Jun 23, 2025
20 changes: 19 additions & 1 deletion libs/partners/groq/langchain_groq/chat_models.py
@@ -107,6 +107,12 @@ class ChatGroq(BaseChatModel):
Sampling temperature. Ranges from 0.0 to 1.0.
max_tokens: Optional[int]
Max number of tokens to generate.
reasoning_format: Optional[Literal["parsed", "raw", "hidden"]]
The format for reasoning output.

- ``parsed``: Separates reasoning into a dedicated field while keeping the response concise.
- ``raw``: Includes reasoning within think tags in the content.
- ``hidden``: Returns only the final answer.
model_kwargs: Dict[str, Any]
Holds any model parameters valid for create call not
explicitly specified.
@@ -292,7 +298,7 @@ class Joke(BaseModel):
'system_fingerprint': 'fp_c5f20b5bb1',
'finish_reason': 'stop',
'logprobs': None}
"""
""" # noqa: E501

client: Any = Field(default=None, exclude=True) #: :meta private:
async_client: Any = Field(default=None, exclude=True) #: :meta private:
@@ -302,6 +308,13 @@ class Joke(BaseModel):
"""What sampling temperature to use."""
stop: Optional[Union[list[str], str]] = Field(default=None, alias="stop_sequences")
"""Default stop sequences."""
reasoning_format: Optional[Literal["parsed", "raw", "hidden"]] = None
"""The format for reasoning output.

- ``parsed``: Separates reasoning into a dedicated field while keeping the response concise.
- ``raw``: Includes reasoning within think tags in the content.
- ``hidden``: Returns only the final answer.
""" # noqa: E501
model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
groq_api_key: Optional[SecretStr] = Field(
@@ -606,6 +619,7 @@ def _default_params(self) -> dict[str, Any]:
"n": self.n,
"temperature": self.temperature,
"stop": self.stop,
"reasoning_format": self.reasoning_format,
**self.model_kwargs,
}
if self.max_tokens is not None:
@@ -1153,6 +1167,8 @@ def _convert_chunk_to_message_chunk(
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
if reasoning := _dict.get("reasoning"):
additional_kwargs["reasoning_content"] = reasoning
if usage := (chunk.get("x_groq") or {}).get("usage"):
input_tokens = usage.get("prompt_tokens", 0)
output_tokens = usage.get("completion_tokens", 0)
@@ -1196,6 +1212,8 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
elif role == "assistant":
content = _dict.get("content", "") or ""
additional_kwargs: dict = {}
if reasoning := _dict.get("reasoning"):
additional_kwargs["reasoning_content"] = reasoning
if function_call := _dict.get("function_call"):
additional_kwargs["function_call"] = dict(function_call)
tool_calls = []
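With these conversion changes, the model's reasoning lands in `additional_kwargs["reasoning_content"]` on the returned message. A minimal usage sketch, assuming `GROQ_API_KEY` is set in the environment (the model name is the one used in the integration tests below):

from langchain_groq import ChatGroq

# reasoning_format="parsed" keeps the final answer in .content and
# surfaces the chain of thought separately in additional_kwargs.
chat = ChatGroq(
    model="deepseek-r1-distill-llama-70b",
    reasoning_format="parsed",
)

response = chat.invoke("Translate to French: I love programming.")
print(response.content)                                     # final answer only
print(response.additional_kwargs.get("reasoning_content"))  # separated reasoning

When streaming, chunk addition merges `additional_kwargs`, so `reasoning_content` accumulates across chunks — the streaming test below relies on exactly that.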
54 changes: 53 additions & 1 deletion libs/partners/groq/tests/integration_tests/test_chat_models.py
@@ -1,7 +1,7 @@
"""Test ChatGroq chat model."""

import json
from typing import Any, Optional, cast

import pytest
from langchain_core.messages import (
@@ -212,6 +212,58 @@ async def test_agenerate_streaming() -> None:
assert generation.text == generation.message.content


#
# Test reasoning output
#
def test_reasoning_output_invoke() -> None:
"""Test reasoning output from ChatGroq with invoke."""
chat = ChatGroq(
model="deepseek-r1-distill-llama-70b",
reasoning_format="parsed",
)
message = [
SystemMessage(
content="You are a helpful assistant that translates English to French."
),
HumanMessage(content="I love programming."),
]
response = chat.invoke(message)
assert isinstance(response, AIMessage)
assert "reasoning_content" in response.additional_kwargs
assert isinstance(response.additional_kwargs["reasoning_content"], str)
assert len(response.additional_kwargs["reasoning_content"]) > 0


def test_reasoning_output_stream() -> None:
"""Test reasoning output from ChatGroq with stream."""
chat = ChatGroq(
model="deepseek-r1-distill-llama-70b",
reasoning_format="parsed",
)
message = [
SystemMessage(
content="You are a helpful assistant that translates English to French."
),
HumanMessage(content="I love programming."),
]

full_response: Optional[AIMessageChunk] = None
for token in chat.stream(message):
assert isinstance(token, AIMessageChunk)

if full_response is None:
full_response = token
else:
# Casting since adding results in a type error
full_response = cast(AIMessageChunk, full_response + token)

assert full_response is not None
assert isinstance(full_response, AIMessageChunk)
assert "reasoning_content" in full_response.additional_kwargs
assert isinstance(full_response.additional_kwargs["reasoning_content"], str)
assert len(full_response.additional_kwargs["reasoning_content"]) > 0


#
# Misc tests
#
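The tests above exercise the `parsed` format only. With `reasoning_format="raw"` the reasoning is kept inline in the content inside think tags, per the docstring; below is a hedged sketch of splitting it back out — the exact `<think>` tag shape is an assumption worth verifying against real Groq output:

import re

def split_raw_reasoning(content: str) -> tuple[str, str]:
    """Split raw-format output into (reasoning, answer).

    Assumes the reasoning is wrapped in <think>...</think> tags, as the
    reasoning_format docstring describes; returns empty reasoning if absent.
    """
    match = re.search(r"<think>(.*?)</think>", content, flags=re.DOTALL)
    if match is None:
        return "", content.strip()
    reasoning = match.group(1).strip()
    answer = (content[: match.start()] + content[match.end():]).strip()
    return reasoning, answer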