Skip to content

Commit 0763052

Browse files
committed
refactor: get rid of langchain fully
1 parent d10c036 commit 0763052

File tree

15 files changed

+692
-1068
lines changed

15 files changed

+692
-1068
lines changed

chatbot/bot/conversation/conversation_retrieval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from asyncio import get_event_loop
22
from typing import Any, List, Tuple
33

4+
from entities.document import Document
45
from helpers.log import get_logger
5-
from langchain_core.documents import Document
66

77
from bot.client.lama_cpp_client import LamaCppClient
88
from bot.conversation.ctx_strategy import AsyncTreeSummarizationStrategy, BaseSynthesisStrategy

chatbot/bot/conversation/ctx_strategy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from typing import Any, List, Union
44

55
import nest_asyncio
6+
from entities.document import Document
67
from helpers.log import get_logger
7-
from langchain_core.documents import Document
88

99
from bot.client.lama_cpp_client import LamaCppClient
1010

chatbot/bot/memory/embedder.py

Lines changed: 73 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,80 @@
1-
from abc import ABC
1+
from abc import ABC, abstractmethod
22
from typing import Any
33

4-
from langchain.embeddings import HuggingFaceEmbeddings
5-
64

75
class Embedder(ABC):
8-
embedder: Any
6+
@abstractmethod
7+
def embed_documents(self, texts: list[str]) -> list[list[float]]:
8+
"""Embed search docs."""
9+
10+
@abstractmethod
11+
def embed_query(self, text: str) -> list[float]:
12+
"""Embed query text."""
13+
14+
15+
class HuggingFaceEmbedder(Embedder):
16+
"""HuggingFace sentence_transformers embedding models.
17+
18+
To use, you should have the ``sentence_transformers`` python package installed.
19+
"""
20+
21+
client: Any #: :meta private:
22+
model_name: str = "all-MiniLM-L6-v2"
23+
"""Model name to use."""
24+
cache_folder: str | None = None
25+
"""Path to store models.
26+
Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
27+
model_kwargs: dict[str, Any] = {}
28+
"""Keyword arguments to pass to the model."""
29+
encode_kwargs: dict[str, Any] = {}
30+
"""Keyword arguments to pass when calling the `encode` method of the model."""
31+
multi_process: bool = False
32+
"""Run encode() on multiple GPUs."""
33+
34+
def __init__(self, **kwargs: Any):
35+
"""Initialize the sentence_transformer."""
36+
super().__init__(**kwargs)
37+
try:
38+
import sentence_transformers
39+
40+
except ImportError as exc:
41+
raise ImportError(
42+
"Could not import sentence_transformers python package. "
43+
"Please install it with `pip install sentence-transformers`."
44+
) from exc
45+
46+
self.client = sentence_transformers.SentenceTransformer(
47+
self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
48+
)
49+
50+
def embed_documents(self, texts: list[str]) -> list[list[float]]:
51+
"""Compute doc embeddings using a HuggingFace transformer model.
52+
53+
Args:
54+
texts: The list of texts to embed.
55+
56+
Returns:
57+
List of embeddings, one for each text.
58+
"""
59+
import sentence_transformers
60+
61+
texts = list(map(lambda x: x.replace("\n", " "), texts))
62+
if self.multi_process:
63+
pool = self.client.start_multi_process_pool()
64+
embeddings = self.client.encode_multi_process(texts, pool)
65+
sentence_transformers.SentenceTransformer.stop_multi_process_pool(pool)
66+
else:
67+
embeddings = self.client.encode(texts, **self.encode_kwargs)
68+
69+
return embeddings.tolist()
970

10-
def get_embedding(self):
11-
return self.embedder
71+
def embed_query(self, text: str) -> list[float]:
72+
"""Compute query embeddings using a HuggingFace transformer model.
1273
74+
Args:
75+
text: The text to embed.
1376
14-
class EmbedderHuggingFace(Embedder):
15-
def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
16-
self.embedder = HuggingFaceEmbeddings(model_name=model_name)
77+
Returns:
78+
Embeddings for the text.
79+
"""
80+
return self.embed_documents([text])[0]

chatbot/bot/memory/vector_memory.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from typing import Any, Dict, List, Tuple
22

33
from cleantext import clean
4+
from entities.document import Document
45
from helpers.log import get_logger
5-
from langchain.vectorstores import Chroma
6-
from langchain_core.documents import Document
6+
from vector_database.chroma import Chroma
77

88
logger = get_logger(__name__)
99

@@ -99,10 +99,9 @@ def similarity_search(
9999
def create_memory_index(embedding: Any, chunks: List, vector_store_path: str):
100100
texts = [clean(doc.page_content, no_emoji=True) for doc in chunks]
101101
metadatas = [doc.metadata for doc in chunks]
102-
memory_index = Chroma.from_texts(
102+
Chroma.from_texts(
103103
texts=texts,
104104
embedding=embedding,
105105
metadatas=metadatas,
106106
persist_directory=vector_store_path,
107107
)
108-
memory_index.persist()

chatbot/cli/rag_chatbot.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from bot.client.lama_cpp_client import LamaCppClient
77
from bot.conversation.conversation_retrieval import ConversationRetrieval
88
from bot.conversation.ctx_strategy import get_ctx_synthesis_strategies, get_ctx_synthesis_strategy
9-
from bot.memory.embedder import EmbedderHuggingFace
9+
from bot.memory.embedder import HuggingFaceEmbedder
1010
from bot.memory.vector_memory import VectorMemory
1111
from bot.model.model_settings import get_model_setting, get_models
1212
from helpers.log import get_logger
@@ -135,7 +135,7 @@ def main(parameters):
135135

136136
conversation = ConversationRetrieval(llm)
137137

138-
embedding = EmbedderHuggingFace().get_embedding()
138+
embedding = HuggingFaceEmbedder()
139139
index = VectorMemory(vector_store_path=str(vector_store_path), embedding=embedding)
140140

141141
loop(conversation, synthesis_strategy, index, parameters)

chatbot/document_loader/format.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from enum import Enum
2+
3+
4+
class Format(str, Enum):
    """Document formats with dedicated chunking separators.

    Mixes in ``str`` so members compare equal to their raw values
    (``Format.MARKDOWN == "markdown"``), letting callers pass either the
    enum member or the plain string to :func:`get_separators`.
    """

    MARKDOWN = "markdown"
    HTML = "html"


# Ordered separator lists, tried from most to least structural.
SUPPORTED_FORMATS = {
    Format.MARKDOWN.value: [
        # First, try to split along Markdown headings (starting with level 2)
        "\n#{1,6} ",
        # Note the alternative syntax for headings (below) is not handled here
        # Heading level 2
        # ---------------
        # End of code block
        "```\n",
        # Horizontal lines
        "\n\\*\\*\\*+\n",
        "\n---+\n",
        "\n___+\n",
        # Fall back to paragraph, line, word, and character boundaries.
        "\n\n",
        "\n",
        " ",
        "",
    ],
    Format.HTML.value: [
        # First, try to split along HTML tags
        "<body",
        "<div",
        "<p",
        "<br",
        "<li",
        "<h1",
        "<h2",
        "<h3",
        "<h4",
        "<h5",
        "<h6",
        "<span",
        "<table",
        "<tr",
        "<td",
        "<th",
        "<ul",
        "<ol",
        "<header",
        "<footer",
        "<nav",
        # Head
        "<head",
        "<style",
        "<script",
        "<meta",
        "<title",
        "",
    ],
}


def get_separators(format: str) -> list[str]:
    """Return the separator list for the given format.

    Args:
        format: A supported format name ("markdown" or "html"), or a
            ``Format`` member (equal to its value thanks to the str mixin).

    Returns:
        The ordered list of separators for that format.

    Raises:
        KeyError: If the format is not supported.
    """
    separators = SUPPORTED_FORMATS.get(format)

    # validate input; f-string also avoids TypeError on non-str inputs
    if separators is None:
        raise KeyError(f"{format} is not a supported format")

    return separators

chatbot/document_loader/text_splitter.py

Lines changed: 6 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -2,74 +2,13 @@
22
import logging
33
import re
44
from abc import ABC, abstractmethod
5-
from enum import Enum
65
from typing import Any, Callable, Iterable
76

87
from entities.document import Document
98

10-
logger = logging.getLogger(__name__)
11-
12-
13-
class Format(str, Enum):
14-
MARKDOWN = "markdown"
15-
HTML = "html"
16-
9+
from document_loader.format import get_separators
1710

18-
def __get_separators(format: Format) -> list[str]:
19-
if format == Format.MARKDOWN:
20-
return [
21-
# First, try to split along Markdown headings (starting with level 2)
22-
"\n#{1,6} ",
23-
# Note the alternative syntax for headings (below) is not handled here
24-
# Heading level 2
25-
# ---------------
26-
# End of code block
27-
"```\n",
28-
# Horizontal lines
29-
"\n\\*\\*\\*+\n",
30-
"\n---+\n",
31-
"\n___+\n",
32-
# Note that this splitter doesn't handle horizontal lines defined
33-
# by *three or more* of ***, ---, or ___, but this is not handled
34-
"\n\n",
35-
"\n",
36-
" ",
37-
"",
38-
]
39-
elif format == Format.HTML:
40-
return [
41-
# First, try to split along HTML tags
42-
"<body",
43-
"<div",
44-
"<p",
45-
"<br",
46-
"<li",
47-
"<h1",
48-
"<h2",
49-
"<h3",
50-
"<h4",
51-
"<h5",
52-
"<h6",
53-
"<span",
54-
"<table",
55-
"<tr",
56-
"<td",
57-
"<th",
58-
"<ul",
59-
"<ol",
60-
"<header",
61-
"<footer",
62-
"<nav",
63-
# Head
64-
"<head",
65-
"<style",
66-
"<script",
67-
"<meta",
68-
"<title",
69-
"",
70-
]
71-
else:
72-
raise ValueError(f"Language {format} is not supported! " f"Please choose from {list(Format)}")
11+
logger = logging.getLogger(__name__)
7312

7413

7514
class TextSplitter(ABC):
@@ -157,10 +96,9 @@ def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]:
15796
# We now want to combine these smaller pieces into medium size
15897
# chunks to send to the LLM.
15998
separator_len = self._length_function(separator)
160-
161-
docs = []
162-
current_doc: list[str] = []
99+
docs, current_doc = [], []
163100
total = 0
101+
164102
for d in splits:
165103
_len = self._length_function(d)
166104
if total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size:
@@ -281,7 +219,7 @@ def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> l
281219
return [s for s in splits if s != ""]
282220

283221

284-
def create_recursive_text_splitter(format: Format, **kwargs: Any) -> RecursiveCharacterTextSplitter:
222+
def create_recursive_text_splitter(format: str, **kwargs: Any) -> RecursiveCharacterTextSplitter:
285223
"""
286224
Factory function to create a RecursiveCharacterTextSplitter instance based on the specified format.
287225
@@ -292,5 +230,5 @@ def create_recursive_text_splitter(format: Format, **kwargs: Any) -> RecursiveCh
292230
Returns:
293231
An instance of RecursiveCharacterTextSplitter configured with the appropriate separators.
294232
"""
295-
separators = __get_separators(format)
233+
separators = get_separators(format)
296234
return RecursiveCharacterTextSplitter(separators=separators, **kwargs)

chatbot/memory_builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from pathlib import Path
44
from typing import List
55

6-
from bot.memory.embedder import EmbedderHuggingFace
6+
from bot.memory.embedder import HuggingFaceEmbedder
77
from bot.memory.vector_memory import VectorMemory
88
from document_loader.loader import DirectoryLoader
99
from document_loader.text_splitter import Format, create_recursive_text_splitter
@@ -62,7 +62,7 @@ def build_memory_index(docs_path: Path, vector_store_path: str, chunk_size: int,
6262
logger.info(f"Number of generated chunks: {len(chunks)}")
6363

6464
logger.info("Creating memory index...")
65-
embedding = EmbedderHuggingFace().get_embedding()
65+
embedding = HuggingFaceEmbedder()
6666
VectorMemory.create_memory_index(embedding, chunks, vector_store_path)
6767
logger.info("Memory Index has been created successfully!")
6868

chatbot/rag_chatbot_app.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
get_ctx_synthesis_strategies,
1212
get_ctx_synthesis_strategy,
1313
)
14-
from bot.memory.embedder import EmbedderHuggingFace
14+
from bot.memory.embedder import HuggingFaceEmbedder
1515
from bot.memory.vector_memory import VectorMemory
1616
from bot.model.model_settings import get_model_setting, get_models
1717
from helpers.log import get_logger
@@ -51,7 +51,7 @@ def load_index(vector_store_path: Path) -> VectorMemory:
5151
Returns:
5252
VectorMemory: An instance of the VectorMemory class with the loaded index.
5353
"""
54-
embedding = EmbedderHuggingFace().get_embedding()
54+
embedding = HuggingFaceEmbedder()
5555
index = VectorMemory(vector_store_path=str(vector_store_path), embedding=embedding)
5656

5757
return index

0 commit comments

Comments
 (0)