# feat: support langchain v1 (#1472)

Base branch: `chore/remove-deprecated-llm-params`
## Custom LLM Provider Registration

NeMo Guardrails supports two types of custom LLM providers:

1. **Text Completion Models** (`BaseLLM`) - For models that work with string prompts
2. **Chat Models** (`BaseChatModel`) - For models that work with message-based conversations
### Custom Text Completion LLM (BaseLLM)

To register a custom text completion LLM provider, create a class that inherits from `BaseLLM` and register it using `register_llm_provider`.
**Required methods:**

- `_call` - Synchronous text completion
- `_llm_type` - Returns the LLM type identifier

**Optional methods:**

- `_acall` - Asynchronous text completion (recommended)
- `_stream` - Streaming text completion
- `_astream` - Async streaming text completion
- `_identifying_params` - Returns parameters for model identification
```python
from typing import Any, Iterator, List, Optional

from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseLLM
from langchain_core.outputs import GenerationChunk

from nemoguardrails.llm.providers import register_llm_provider


class MyCustomTextLLM(BaseLLM):
    """Custom text completion LLM."""

    @property
    def _llm_type(self) -> str:
        return "custom_text_llm"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Synchronous text completion."""
        # Your implementation here
        return "Generated text response"

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Asynchronous text completion (recommended)."""
        # Your async implementation here
        return "Generated text response"

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Optional: Streaming text completion."""
        # Yield chunks of text
        yield GenerationChunk(text="chunk1")
        yield GenerationChunk(text="chunk2")


register_llm_provider("custom_text_llm", MyCustomTextLLM)
```
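One caveat worth hedging: in current langchain-core, `_call` is declared on the `LLM` convenience subclass, while `BaseLLM` itself declares `_generate` as abstract, so if instantiation complains about a missing `_generate` you can inherit from `langchain_core.language_models.LLM` instead. Once the class is concrete, it behaves like any LangChain LLM, so a quick smoke test (illustrative only) looks like:

```python
llm = MyCustomTextLLM()

print(llm.invoke("Hello!"))  # -> "Generated text response"

for chunk in llm.stream("Hello!"):
    print(chunk)  # -> "chunk1", then "chunk2"
```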
### Custom Chat Model (BaseChatModel)

To register a custom chat model, create a class that inherits from `BaseChatModel` and register it using `register_chat_provider`.

**Required methods:**

- `_generate` - Synchronous chat completion
- `_llm_type` - Returns the LLM type identifier

**Optional methods:**

- `_agenerate` - Asynchronous chat completion (recommended)
- `_stream` - Streaming chat completion
- `_astream` - Async streaming chat completion
```python
from typing import Any, Iterator, List, Optional

from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult

from nemoguardrails.llm.providers import register_chat_provider


class MyCustomChatModel(BaseChatModel):
    """Custom chat model."""

    @property
    def _llm_type(self) -> str:
        return "custom_chat_model"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Synchronous chat completion."""
        # Convert messages to your model's format and generate a response
        response_text = "Generated chat response"

        message = AIMessage(content=response_text)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronous chat completion (recommended)."""
        # Your async implementation
        response_text = "Generated chat response"

        message = AIMessage(content=response_text)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Optional: Streaming chat completion."""
        # Yield chunks
        chunk = ChatGenerationChunk(message=AIMessageChunk(content="chunk1"))
        yield chunk


register_chat_provider("custom_chat_model", MyCustomChatModel)
```
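Because the class is a standard LangChain chat model, you can smoke-test it directly through the Runnable interface before plugging it into Guardrails (a minimal sketch):

```python
from langchain_core.messages import HumanMessage

chat = MyCustomChatModel()

result = chat.invoke([HumanMessage(content="Hello!")])
print(result.content)  # -> "Generated chat response"

for chunk in chat.stream([HumanMessage(content="Hello!")]):
    print(chunk.content)  # -> "chunk1"
```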
### Using Custom LLM Providers

After registering your custom provider, you can use it in your configuration:
```yaml
models:
  - type: main
    engine: custom_text_llm # or custom_chat_model
```
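To wire this up end to end, a short sketch (the `./config` directory path and the prompt are illustrative):

```python
from nemoguardrails import LLMRails, RailsConfig

# Assumes the YAML above (plus your rails definitions) lives in ./config
config = RailsConfig.from_path("./config")
rails = LLMRails(config)

response = rails.generate(messages=[{"role": "user", "content": "Hello!"}])
print(response["content"])
```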
### Important Notes

1. **Import from langchain-core:** Always import base classes from `langchain_core.language_models`:

   ```python
   from langchain_core.language_models import BaseLLM, BaseChatModel
   ```

2. **Implement async methods:** For better performance, always implement `_acall` (for `BaseLLM`) or `_agenerate` (for `BaseChatModel`).

3. **Choose the right base class:**
   - Use `BaseLLM` for text completion models (prompt → text)
   - Use `BaseChatModel` for chat models (messages → message)

4. **Registration functions:**
   - Use `register_llm_provider()` for `BaseLLM` subclasses
   - Use `register_chat_provider()` for `BaseChatModel` subclasses
## Custom Embedding Provider Registration

You can also register a custom embedding provider by using the `LLMRails.register_embedding_provider` function.
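A minimal sketch of what such a provider can look like. The `EmbeddingModel` base class, its `encode`/`encode_async` interface, and the exact registration signature are assumptions about the current API, so check the API reference before relying on them:

```python
from typing import List

from nemoguardrails import LLMRails
from nemoguardrails.embeddings.providers.base import EmbeddingModel


class MyCustomEmbeddings(EmbeddingModel):
    """Illustrative embedding provider; replace the stub vectors with real calls."""

    engine_name = "custom_embeddings"

    def encode(self, documents: List[str]) -> List[List[float]]:
        # Call your embedding backend here; zero vectors are a stand-in
        return [[0.0] * 768 for _ in documents]

    async def encode_async(self, documents: List[str]) -> List[List[float]]:
        return self.encode(documents)


# Exact signature is an assumption; see the API reference
LLMRails.register_embedding_provider(MyCustomEmbeddings, "custom_embeddings")
```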
---

### Chains as Actions
> **⚠️ DEPRECATED**: Chain support is deprecated and will be removed in a future release. Please use [Runnable](https://python.langchain.com/docs/expression_language/) instead. See the [Runnable as Action Guide](langchain/runnable-as-action/README.md) for examples.

> Reviewer note (style): Link to migration guide is incorrect - should be …

You can register a LangChain chain as an action using the [LLMRails.register_action](../api/nemoguardrails.rails.llm.llmrails.md#method-llmrailsregister_action) method:
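(A hedged sketch of the pattern; the chain, `llm`, and the `rails` instance below are illustrative, not taken from this diff.)

```python
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate

# `llm` is any LangChain-compatible model and `rails` an LLMRails instance (assumptions)
prompt = PromptTemplate.from_template("Summarize the following text: {text}")
chain = prompt | llm | StrOutputParser()

# The registered name is what guardrails flows use to call the action
rails.register_action(chain, name="summarize")
```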
---
```python
from langchain_core.language_models.llms import BaseLLM
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate

from nemoguardrails import LLMRails
from nemoguardrails.actions.actions import ActionResult
```

> Reviewer note (syntax): import path for BaseLLM changed from …
---

```python
import pandas as pd
import torch
from gpt4pandas import GPT4Pandas

try:
    from langchain.chains import RetrievalQA
    from langchain.embeddings import HuggingFaceEmbeddings
    from langchain.text_splitter import CharacterTextSplitter
    from langchain.vectorstores import FAISS
except ImportError:
    try:
        from langchain_classic.chains import RetrievalQA
        from langchain_classic.embeddings import HuggingFaceEmbeddings
        from langchain_classic.text_splitter import CharacterTextSplitter
        from langchain_classic.vectorstores import FAISS
    except ImportError as e:
        raise ImportError(
            "Failed to import from langchain. If you're using LangChain >= 1.0.0, "
            "please install langchain-classic: pip install langchain-classic"
        ) from e

from langchain_core.language_models.llms import BaseLLM
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
```

> Reviewer note (logic, lines +26 to +40): The nested try-except blocks will shadow the original ImportError if …
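For context, a hedged sketch of how these imports are typically wired together once the fallback succeeds (the model name, document text, and `llm` are assumptions, not part of this PR):

```python
# Assumes the imports above succeeded; `llm` is any LangChain LLM (assumption)
splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=0)
docs = splitter.create_documents(["Your knowledge-base text goes here."])

# Embed the chunks and index them in FAISS
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vectorstore = FAISS.from_documents(docs, embeddings)

# Retrieval-augmented QA chain over the index
qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=vectorstore.as_retriever())
```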
---

```python
from typing import Optional

import pinecone

try:
    from langchain.chains import RetrievalQA
    from langchain.docstore.document import Document
    from langchain.embeddings.openai import OpenAIEmbeddings
    from langchain.vectorstores import Pinecone
except ImportError:
    try:
        from langchain_classic.chains import RetrievalQA
        from langchain_classic.docstore.document import Document
        from langchain_classic.embeddings.openai import OpenAIEmbeddings
        from langchain_classic.vectorstores import Pinecone
    except ImportError as e:
        raise ImportError(
            "Failed to import from langchain. If you're using LangChain >= 1.0.0, "
            "please install langchain-classic: pip install langchain-classic"
        ) from e

from langchain_core.language_models.llms import BaseLLM

from nemoguardrails import LLMRails
```

> Reviewer note (syntax): Import path …
---

```python
import os

try:
    from langchain.chains import LLMMathChain
except ImportError:
    try:
        from langchain_classic.chains import LLMMathChain
    except ImportError as e:
        raise ImportError(
            "Failed to import LLMMathChain. If you're using LangChain >= 1.0.0, "
            "please install langchain-classic: pip install langchain-classic"
        ) from e

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import Tool
from langchain_openai.chat_models import ChatOpenAI
from pydantic import BaseModel, Field
```
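As a hedged sketch of how these imports usually come together (the model name and tool description are illustrative), the math chain can be exposed as a LangChain tool:

```python
# Illustrative wiring; model name is an assumption
llm = ChatOpenAI(model="gpt-4o-mini")
math_chain = LLMMathChain.from_llm(llm=llm)

calculator = Tool(
    name="Calculator",
    description="Useful for answering math questions.",
    func=math_chain.run,
)
```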
> Reviewer note (style, lines +24 to +26): The error message instructs users to install …
>
> Reviewer note (syntax): Import `BaseLLM` from `langchain_core.language_models` (not `langchain_core.language_models.llms`) to match the LangChain v1 canonical import paths used throughout the codebase.