161 changes: 135 additions & 26 deletions docs/user-guides/configuration-guide/custom-initialization.md
@@ -37,56 +37,65 @@ def init(app: LLMRails):

## Custom LLM Provider Registration

NeMo Guardrails supports two types of custom LLM providers:
1. **Text Completion Models** (`BaseLLM`) - For models that work with string prompts
2. **Chat Models** (`BaseChatModel`) - For models that work with message-based conversations

### Custom Text Completion LLM (BaseLLM)

To register a custom text completion LLM provider, create a class that inherits from `BaseLLM` and register it using `register_llm_provider`.

**Required methods:**
- `_call` - Synchronous text completion
- `_llm_type` - Returns the LLM type identifier

**Optional methods:**
- `_acall` - Asynchronous text completion (recommended)
- `_stream` - Streaming text completion
- `_astream` - Async streaming text completion
- `_identifying_params` - Returns parameters for model identification

```python
from typing import Any, Iterator, List, Optional

from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
# Review note (Contributor): syntax: import BaseLLM from
# langchain_core.language_models (not langchain_core.language_models.llms)
# to match LangChain v1 canonical import paths used throughout the codebase.
from langchain_core.language_models import BaseLLM
from langchain_core.outputs import GenerationChunk

from nemoguardrails.llm.providers import register_llm_provider


class MyCustomTextLLM(BaseLLM):
    """Custom text completion LLM."""

    @property
    def _llm_type(self) -> str:
        return "custom_text_llm"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Synchronous text completion."""
        # Your implementation here
        return "Generated text response"

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Asynchronous text completion (recommended)."""
        # Your async implementation here
        return "Generated text response"

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Optional: Streaming text completion."""
        # Yield chunks of text
        yield GenerationChunk(text="chunk1")
        yield GenerationChunk(text="chunk2")


register_llm_provider("custom_text_llm", MyCustomTextLLM)
```
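
The registration call can live in your guardrails configuration's `config.py`, so that it runs when the configuration is loaded. A minimal sketch (the module name `my_custom_llm` is illustrative):

```python
# config.py (sketch; assumes MyCustomTextLLM is defined in my_custom_llm.py)
from my_custom_llm import MyCustomTextLLM

from nemoguardrails.llm.providers import register_llm_provider

register_llm_provider("custom_text_llm", MyCustomTextLLM)
```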

### Custom Chat Model (BaseChatModel)

To register a custom chat model, create a class that inherits from `BaseChatModel` and register it using `register_chat_provider`.

**Required methods:**
- `_generate` - Synchronous chat completion
- `_llm_type` - Returns the LLM type identifier

**Optional methods:**
- `_agenerate` - Asynchronous chat completion (recommended)
- `_stream` - Streaming chat completion
- `_astream` - Async streaming chat completion

```python
from typing import Any, Iterator, List, Optional

from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
# Review note (Contributor): syntax: import BaseChatModel from
# langchain_core.language_models (not langchain_core.language_models.chat_models)
# to match LangChain v1 canonical import paths used throughout the codebase.
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult

from nemoguardrails.llm.providers import register_chat_provider


class MyCustomChatModel(BaseChatModel):
    """Custom chat model."""

    @property
    def _llm_type(self) -> str:
        return "custom_chat_model"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Synchronous chat completion."""
        # Convert messages to your model's format and generate a response
        response_text = "Generated chat response"

        message = AIMessage(content=response_text)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronous chat completion (recommended)."""
        # Your async implementation
        response_text = "Generated chat response"

        message = AIMessage(content=response_text)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Optional: Streaming chat completion."""
        # Yield chunks
        chunk = ChatGenerationChunk(message=AIMessageChunk(content="chunk1"))
        yield chunk


register_chat_provider("custom_chat_model", MyCustomChatModel)
```
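
Because the class above is a standard LangChain chat model, you can sanity-check it directly before wiring it into a guardrails configuration. A quick usage sketch:

```python
from langchain_core.messages import HumanMessage

chat = MyCustomChatModel()
result = chat.invoke([HumanMessage(content="Hello!")])
print(result.content)  # -> "Generated chat response"
```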

### Using Custom LLM Providers

After registering your custom provider, you can use it in your configuration:

```yaml
models:
  - type: main
    engine: custom_text_llm  # or custom_chat_model
```
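
With the provider registered and referenced from `config.yml`, the rails app uses it like any built-in engine. A minimal sketch, assuming the configuration lives in `./config`:

```python
from nemoguardrails import LLMRails, RailsConfig

config = RailsConfig.from_path("./config")
rails = LLMRails(config)

response = rails.generate(messages=[{"role": "user", "content": "Hello!"}])
print(response["content"])
```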

### Important Notes

1. **Import from langchain-core:** Always import base classes from `langchain_core.language_models`:

   ```python
   from langchain_core.language_models import BaseLLM, BaseChatModel
   ```

2. **Implement async methods:** For better performance, always implement `_acall` (for `BaseLLM`) or `_agenerate` (for `BaseChatModel`).

3. **Choose the right base class:**
   - Use `BaseLLM` for text completion models (prompt → text)
   - Use `BaseChatModel` for chat models (messages → message)

4. **Registration functions:**
   - Use `register_llm_provider()` for `BaseLLM` subclasses
   - Use `register_chat_provider()` for `BaseChatModel` subclasses
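
If your provider implements the optional streaming methods, they can be exercised through the streaming interface of the rails app. A sketch, assuming `streaming: True` is set in `config.yml`:

```python
import asyncio

from nemoguardrails import LLMRails, RailsConfig


async def demo():
    config = RailsConfig.from_path("./config")
    rails = LLMRails(config)

    # Chunks are produced by the provider's _stream/_astream implementation.
    async for chunk in rails.stream_async(
        messages=[{"role": "user", "content": "Hello!"}]
    ):
        print(chunk, end="")


asyncio.run(demo())
```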

## Custom Embedding Provider Registration

You can also register a custom embedding provider by using the `LLMRails.register_embedding_provider` function.
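
A rough sketch of what such a provider can look like (the base-class import path and method names below are assumptions to verify against your installed version):

```python
from typing import List

from nemoguardrails import LLMRails
from nemoguardrails.embeddings.providers.base import EmbeddingModel


class CustomEmbeddingModel(EmbeddingModel):
    """Sketch of a custom embedding provider."""

    engine_name = "custom_embeddings"

    def encode(self, documents: List[str]) -> List[List[float]]:
        # Return one embedding vector per document.
        return [[0.0] * 768 for _ in documents]

    async def encode_async(self, documents: List[str]) -> List[List[float]]:
        return self.encode(documents)


def init(app: LLMRails):
    app.register_embedding_provider(CustomEmbeddingModel, "custom_embeddings")
```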
2 changes: 2 additions & 0 deletions docs/user-guides/python-api.md
@@ -132,6 +132,8 @@ For convenience, this toolkit also includes a selection of LangChain tools, wrapped

### Chains as Actions

> **⚠️ DEPRECATED**: Chain support is deprecated and will be removed in a future release. Please use [Runnable](https://python.langchain.com/docs/expression_language/) instead. See the [Runnable as Action Guide](langchain/runnable-as-action/README.md) for examples.
> Review comment (Contributor): style: Link to migration guide is incorrect - should be `../langchain/runnable-as-action/README.md` (add `../` prefix) since this file is in `user-guides/` and the target is in `user-guides/langchain/runnable-as-action/`.


You can register a Langchain chain as an action using the [LLMRails.register_action](../api/nemoguardrails.rails.llm.llmrails.md#method-llmrailsregister_action) method:

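For example (a sketch; given the deprecation note above, a `Runnable` is used in place of a chain, and the action name is illustrative):

```python
from langchain_core.runnables import RunnableLambda

from nemoguardrails import LLMRails, RailsConfig

config = RailsConfig.from_path("./config")
app = LLMRails(config)

# A trivial Runnable standing in for a real chain.
word_count = RunnableLambda(lambda text: len(text.split()))

# The action becomes callable from Colang flows as `word_count`.
app.register_action(word_count, name="word_count")
```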
2 changes: 1 addition & 1 deletion examples/configs/rag/custom_rag_output_rails/config.py
@@ -13,9 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Review note (Contributor): syntax: import path for BaseLLM changed from
# langchain_core.language_models.llms to langchain_core.language_models in
# LangChain 1.x. Use `from langchain_core.language_models import BaseLLM`
# for consistency with other files in this PR.
from langchain_core.language_models.llms import BaseLLM

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate

from nemoguardrails import LLMRails
from nemoguardrails.actions.actions import ActionResult
22 changes: 18 additions & 4 deletions examples/configs/rag/multi_kb/config.py
@@ -21,10 +21,24 @@
import pandas as pd
import torch
from gpt4pandas import GPT4Pandas

try:
    from langchain.chains import RetrievalQA
    from langchain.embeddings import HuggingFaceEmbeddings
    from langchain.text_splitter import CharacterTextSplitter
    from langchain.vectorstores import FAISS
except ImportError:
    try:
        from langchain_classic.chains import RetrievalQA
        from langchain_classic.embeddings import HuggingFaceEmbeddings
        from langchain_classic.text_splitter import CharacterTextSplitter
        from langchain_classic.vectorstores import FAISS
    except ImportError as e:
        raise ImportError(
            "Failed to import from langchain. If you're using LangChain >= 1.0.0, "
            "please install langchain-classic: pip install langchain-classic"
        ) from e
> Review comment (Contributor) on lines +26 to +40: logic: The nested try-except blocks will shadow the original ImportError if langchain_classic also fails to import a different subset of classes. The error message on lines 38-39 will incorrectly suggest installing langchain-classic even if RetrievalQA imports but HuggingFaceEmbeddings doesn't (for example). Use `raise ... from None` or re-raise the first ImportError.

from langchain_core.language_models.llms import BaseLLM
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

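A variant along the lines of the reviewer's suggestion, which surfaces the original failure instead of blaming `langchain-classic` (a sketch):

```python
try:
    from langchain.chains import RetrievalQA
    from langchain.embeddings import HuggingFaceEmbeddings
    from langchain.text_splitter import CharacterTextSplitter
    from langchain.vectorstores import FAISS
except ImportError as first_error:
    try:
        from langchain_classic.chains import RetrievalQA
        from langchain_classic.embeddings import HuggingFaceEmbeddings
        from langchain_classic.text_splitter import CharacterTextSplitter
        from langchain_classic.vectorstores import FAISS
    except ImportError:
        # Re-raise the first error: the fallback may fail on a different
        # symbol, making the "pip install langchain-classic" hint misleading.
        raise first_error from None
```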
7 changes: 3 additions & 4 deletions examples/configs/rag/multi_kb/tabular_llm.py
@@ -13,14 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional

from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import BaseLLM


def query_tabular_data(usr_query: str, gpt: any, raw_data_frame: any):
@@ -58,7 +57,7 @@ def query_tabular_data(usr_query: str, gpt: any, raw_data_frame: any):
    return out, d2.to_string()


class TabularLLM(BaseLLM):
    """LLM wrapping for GPT4Pandas."""

    model: str = ""
22 changes: 18 additions & 4 deletions examples/configs/rag/pinecone/config.py
@@ -18,10 +18,24 @@
from typing import Optional

import pinecone

try:
    from langchain.chains import RetrievalQA
    from langchain.docstore.document import Document
    from langchain.embeddings.openai import OpenAIEmbeddings
    from langchain.vectorstores import Pinecone
except ImportError:
    try:
        from langchain_classic.chains import RetrievalQA
        from langchain_classic.docstore.document import Document
        from langchain_classic.embeddings.openai import OpenAIEmbeddings
        from langchain_classic.vectorstores import Pinecone
    except ImportError as e:
        raise ImportError(
            "Failed to import from langchain. If you're using LangChain >= 1.0.0, "
            "please install langchain-classic: pip install langchain-classic"
        ) from e

# Review note (Contributor): syntax: the import path
# langchain_core.language_models.llms is deprecated. Use
# `from langchain_core.language_models import BaseLLM` instead.
from langchain_core.language_models.llms import BaseLLM


from nemoguardrails import LLMRails
14 changes: 12 additions & 2 deletions examples/scripts/langchain/experiments.py
@@ -15,8 +15,18 @@

import os

try:
    from langchain.chains import LLMMathChain
except ImportError:
    try:
        from langchain_classic.chains import LLMMathChain
    except ImportError as e:
        # Review note (Contributor) on lines +24 to +26: the error message
        # instructs users to install langchain-classic, but doesn't mention
        # that this is only needed for LangChain >= 1.0.0. Users on LangChain
        # 0.x who hit this error for a different import issue might be
        # confused; consider verifying the installed LangChain version before
        # suggesting the installation.
        raise ImportError(
            "Failed to import LLMMathChain. If you're using LangChain >= 1.0.0, "
            "please install langchain-classic: pip install langchain-classic"
        ) from e

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import Tool
from langchain_openai.chat_models import ChatOpenAI
from pydantic import BaseModel, Field
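One way to address the reviewer's question, checking the installed LangChain version before recommending `langchain-classic` (a sketch; assumes Python 3.8+ for `importlib.metadata`):

```python
try:
    from langchain.chains import LLMMathChain
except ImportError as e:
    from importlib.metadata import version

    major = int(version("langchain").split(".")[0])
    if major >= 1:
        hint = "install langchain-classic: pip install langchain-classic"
    else:
        hint = "check that your LangChain installation is complete"
    raise ImportError(f"Failed to import LLMMathChain. Please {hint}.") from e
```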
23 changes: 0 additions & 23 deletions nemoguardrails/actions/action_dispatcher.py
@@ -23,12 +23,10 @@
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast

from langchain.chains.base import Chain
from langchain_core.runnables import Runnable

from nemoguardrails import utils
from nemoguardrails.actions.llm.utils import LLMCallException
from nemoguardrails.logging.callbacks import logging_callbacks

log = logging.getLogger(__name__)

@@ -228,27 +226,6 @@ async def execute_action(
f"Synchronous action `{action_name}` has been called."
)

elif isinstance(fn, Chain):
try:
chain = fn

# For chains with only one output key, we use the `arun` function
# to return directly the result.
if len(chain.output_keys) == 1:
result = await chain.arun(
**params, callbacks=logging_callbacks
)
else:
# Otherwise, we return the dict with the output keys.
result = await chain.acall(
inputs=params,
return_only_outputs=True,
callbacks=logging_callbacks,
)
except NotImplementedError:
# Not ideal, but for now we fall back to sync execution
# if the async is not available
result = fn.run(**params)
elif isinstance(fn, Runnable):
# If it's a Runnable, we invoke it as well
runnable = fn
3 changes: 1 addition & 2 deletions nemoguardrails/actions/llm/generation.py
@@ -28,8 +28,7 @@

from jinja2 import meta
from jinja2.sandbox import SandboxedEnvironment
from langchain_core.language_models import BaseChatModel, BaseLLM

from nemoguardrails.actions.actions import ActionResult, action
from nemoguardrails.actions.llm.utils import (