Skip to content

Commit 38d969c

Browse files
authored
Update LangChain LLMs and Add Chat Models Support (#953)
2 parents 0a9c5d8 + ecbc0c8 commit 38d969c

29 files changed

+2654
-1076
lines changed

THIRD_PARTY_LICENSES.txt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,18 @@ langchain
157157
* Source code: https://github.com/langchain-ai/langchain
158158
* Project home: https://www.langchain.com/
159159

160+
langchain-community
161+
* Copyright (c) 2023 LangChain, Inc.
162+
* License: MIT license
163+
* Source code: https://github.com/langchain-ai/langchain/tree/master/libs/community
164+
* Project home: https://github.com/langchain-ai/langchain/tree/master/libs/community
165+
166+
langchain-openai
167+
* Copyright (c) 2023 LangChain, Inc.
168+
* License: MIT license
169+
* Source code: https://github.com/langchain-ai/langchain/tree/master/libs/partners/openai
170+
* Project home: https://github.com/langchain-ai/langchain/tree/master/libs/partners/openai
171+
160172
lightgbm
161173
* Copyright (c) 2023 Microsoft Corporation
162174
* License: MIT license

ads/llm/__init__.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,16 @@
66

77
try:
88
import langchain
9-
from ads.llm.langchain.plugins.llm_gen_ai import GenerativeAI
10-
from ads.llm.langchain.plugins.llm_md import ModelDeploymentTGI
11-
from ads.llm.langchain.plugins.llm_md import ModelDeploymentVLLM
12-
from ads.llm.langchain.plugins.embeddings import GenerativeAIEmbeddings
9+
from ads.llm.langchain.plugins.llms.oci_data_science_model_deployment_endpoint import (
10+
OCIModelDeploymentVLLM,
11+
OCIModelDeploymentTGI,
12+
)
13+
from ads.llm.langchain.plugins.chat_models.oci_data_science import (
14+
ChatOCIModelDeployment,
15+
ChatOCIModelDeploymentVLLM,
16+
ChatOCIModelDeploymentTGI,
17+
)
18+
from ads.llm.chat_template import ChatTemplates
1319
except ImportError as ex:
1420
if ex.name == "langchain":
1521
raise ImportError(

ads/llm/chat_template.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*--
3+
4+
# Copyright (c) 2023 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6+
7+
8+
import os
9+
10+
11+
class ChatTemplates:
12+
"""Contains chat templates."""
13+
14+
@staticmethod
15+
def _read_template(filename):
16+
with open(
17+
os.path.join(os.path.dirname(__file__), "templates", filename),
18+
mode="r",
19+
encoding="utf-8",
20+
) as f:
21+
return f.read()
22+
23+
@staticmethod
24+
def mistral():
25+
"""Chat template for auto tool calling with Mistral model deploy with vLLM."""
26+
return ChatTemplates._read_template("tool_chat_template_mistral_parallel.jinja")
27+
28+
@staticmethod
29+
def hermes():
30+
"""Chat template for auto tool calling with Hermes model deploy with vLLM."""
31+
return ChatTemplates._read_template("tool_chat_template_hermes.jinja")

ads/llm/guardrails/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from typing import Any, List, Dict, Tuple
1515
from langchain.schema.prompt import PromptValue
1616
from langchain.tools.base import BaseTool, ToolException
17-
from langchain.pydantic_v1 import BaseModel, root_validator
17+
from pydantic import BaseModel, model_validator
1818

1919

2020
class RunInfo(BaseModel):
@@ -190,7 +190,8 @@ class Config:
190190
This is used by the ``apply_filter()`` method.
191191
"""
192192

193-
@root_validator
193+
@model_validator(mode="before")
194+
@classmethod
194195
def default_name(cls, values):
195196
"""Sets the default name of the guardrail."""
196197
if not values.get("name"):

ads/llm/guardrails/huggingface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77

88
import evaluate
9-
from langchain.pydantic_v1 import root_validator
9+
from pydantic.v1 import root_validator
1010
from .base import Guardrail
1111

1212

ads/llm/langchain/plugins/base.py

Lines changed: 0 additions & 118 deletions
This file was deleted.
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*--
3+
4+
# Copyright (c) 2023 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

0 commit comments

Comments
 (0)