Commit
Add joiner param to text_helper.text_split() for better control of regex separator handling.
uogbuji committed Jun 7, 2024
1 parent 917f777 commit 1e5dbd4
Showing 5 changed files with 54 additions and 22 deletions.
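
A quick usage sketch of the new joiner parameter, adapted directly from the doctest examples this commit adds to pylib/text_helper.py (import path inferred from the repo layout; the expected lists mirror those doctests):

from ogbujipt.text_helper import text_split

# Regex separator with no joiner: each chunk keeps the literal separator text it matched
list(text_split('She\tsells seashells\tby the seashore', chunk_size=10, separator='\\s'))
# -> ['She\tsells', 'seashells', 'by the', 'seashore']

# Same split, but sections are rejoined with an explicit single space
list(text_split('She\tsells seashells\tby the seashore', chunk_size=10, separator='\\s', joiner=' '))
# -> ['She sells', 'seashells', 'by the', 'seashore']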
2 changes: 1 addition & 1 deletion pylib/__about__.py
@@ -3,4 +3,4 @@
# SPDX-License-Identifier: Apache-2.0
# ogbujipt.about

__version__ = '0.9.1'
__version__ = '0.9.2'
4 changes: 2 additions & 2 deletions pylib/embedding/pgvector.py
@@ -60,7 +60,7 @@ class PGVectorHelper:
* min_max_size: Tuple of minimum and maximum number of connections to maintain in the pool.
Defaults to (10, 20)
'''
def __init__(self, embedding_model, table_name: str, pool: asyncpg.pool.Pool):
def __init__(self, embedding_model, table_name: str, pool):
'''
If you don't already have a connection pool, construct using the PGvectorHelper.from_pool_params() method
@@ -70,7 +70,7 @@ def __init__(self, embedding_model, table_name: str, pool: asyncpg.pool.Pool):
table_name: PostgresQL table. Checked to restrict to alphanumeric characters & underscore
pool: asyncpg connection pool instance
pool: asyncpg connection pool instance (asyncpg.pool.Pool)
'''
if not PREREQS_AVAILABLE:
raise RuntimeError('pgvector not installed, you can run `pip install pgvector asyncpg`')
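
For context on the pool parameter loosened above, a minimal construction sketch, assuming an existing asyncpg pool and a SentenceTransformer embedding model (DSN, model and table names are placeholders, not from this commit):

import asyncpg
from sentence_transformers import SentenceTransformer
from ogbujipt.embedding.pgvector import PGVectorHelper  # import path inferred from pylib/embedding/pgvector.py

async def build_helper():
    emb_model = SentenceTransformer('all-MiniLM-L6-v2')  # placeholder model
    pool = await asyncpg.create_pool(dsn='postgresql://user:pass@localhost/mydb')  # placeholder DSN
    # The pool argument is now duck-typed; the docstring, rather than an annotation,
    # records that it is expected to be an asyncpg.pool.Pool
    return PGVectorHelper(emb_model, table_name='my_table', pool=pool)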
15 changes: 10 additions & 5 deletions pylib/embedding/qdrant.py
@@ -94,10 +94,14 @@ def __init__(self, name, embedding_model, db=None,
self._embedding_model = embedding_model
else:
raise ValueError('embedding_model must be a SentenceTransformer object')

if self.db:
# Assume any passed-in DB has been initialized
# Has the passed-in DB been initialized?
self._db_initialized = True
try:
self.db.get_collection(self.name)
except ValueError:
self._db_initialized = False
elif not QDRANT_AVAILABLE:
raise RuntimeError('Qdrant not installed, you can run `pip install qdrant-client`')
else:
@@ -204,7 +208,8 @@ def search(self, query, **kwargs):
limit - maximum number of results to return (useful for top-k query)
'''
if not self._db_initialized:
raise RuntimeError('Qdrant Collection must be initialized before searching its contents.')
warnings.warn('Qdrant Collection must be initialized. No contents.')
return []

if query.__class__.__name__ != 'str':
raise ValueError('query must be a string')
@@ -216,8 +221,8 @@ def count(self):
Return the count of items in this Qdrant collection
'''
if not self._db_initialized:
raise RuntimeError('Qdrant Collection must be initialized before counting its contents.')
warnings.warn('Qdrant Collection must be initialized. No contents.')
return 0
# This ugly declaration just gets the count as an integer
current_count = int(str(self.db.count(self.name)).partition('=')[-1])
return current_count
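
A short sketch of the behavior change in this file: against a collection whose Qdrant backing has not been initialized, search() and count() now warn and return an empty result rather than raising. The class name and in-memory default below are assumptions inferred from the __init__ signature shown above, not confirmed by this hunk:

import warnings
from sentence_transformers import SentenceTransformer
from ogbujipt.embedding.qdrant import collection  # class name assumed, not shown in this hunk

emb_model = SentenceTransformer('all-MiniLM-L6-v2')  # placeholder model
coll = collection('my_collection', emb_model)  # no db passed; assumed to start uninitialized

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    assert coll.search('some query') == []  # previously raised RuntimeError
    assert coll.count() == 0                # previously raised RuntimeError
assert any('must be initialized' in str(w.message) for w in caught)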

51 changes: 38 additions & 13 deletions pylib/text_helper.py
@@ -7,9 +7,11 @@
'''
import re
import warnings
from itertools import zip_longest


def text_split(text: str, chunk_size: int, separator: str='\n\n', len_func=len):
# XXX: Do we want a case-insensitive separator regex flag?
def text_split(text: str, chunk_size: int, separator: str='\n\n', joiner=None, len_func=len):
'''
Split string and generate the sequence of chunks
@@ -21,13 +23,21 @@ def text_split(text: str, chunk_size: int, separator: str='\n\n', len_func=len):
# Notice case sensitivity, plus the fact that the separator is not included in the chunks
>>> list(text_split('She sells seashells by the seashore', chunk_size=10, separator='s'))
['She ', 'ells ', 'ea', 'hell', ' by the ', 'ea', 'hore']
>>> list(text_split('She\tsells seashells\tby the seashore', chunk_size=10, separator='\\s'))
['She\tsells', 'seashells', 'by the', 'seashore']
>>> list(text_split('She\tsells seashells\tby the seashore', chunk_size=10, separator='\\s', joiner=' '))
['She sells', 'seashells', 'by the', 'seashore']
Args:
text (str): String to be split into chunks
chunk_size (int): Guidance on maximum length (based on len_func) of each chunk
seperator (str, optional): String that already splits "text" into sections
separator (str, optional): Regex used to split `text` into sections. Do not include outer capturing parentheses.
Don't forget to use escaping where necessary.
joiner (str, optional): Exact string used to rejoin any sections in order to meet the target length.
Defaults to using the literal match from the separator.
len_func (callable, optional): Function to measure chunk length, len() by default
@@ -44,11 +54,18 @@ def text_split(text: str, chunk_size: int, separator: str='\n\n', len_func=len):
if ((not isinstance(chunk_size, int)) or (chunk_size <= 0)):
raise ValueError(f'chunk_size must be a positive integer, got {chunk_size}.')

# Split up the text by the separator
# FIXME: Need a step for escaping regex
# Split the text by the separator
if joiner is None:
separator = f'({separator})'
sep_pat = re.compile(separator)
fine_split = re.split(sep_pat, text)
separator_len = len_func(separator)
raw_split = re.split(sep_pat, text)

# Rapid aid to understanding following logic:
# data = ['a',' ','b','\t','c']
# list(zip_longest(data[0::2], data[1::2], fillvalue=''))
# →[('a', ' '), ('b', '\t'), ('c', '')]
fine_split = ([ i for i in zip_longest(raw_split[0::2], raw_split[1::2], fillvalue='') ]
if joiner is None else re.split(sep_pat, text))

if len(fine_split) <= 1:
warnings.warn(f'No splits detected. Perhaps a problem with separator? ({repr(separator)})?')
@@ -57,28 +74,32 @@ def text_split(text: str, chunk_size: int, separator: str='\n\n', len_func=len):
chunk_len = 0

for fs in fine_split:
(fs, sep) = fs if joiner is None else (fs, joiner)
if not fs: continue # noqa E701
# print(fs)
sep_len = len_func(sep)
len_fs = len_func(fs)
# if len_fs > chunk_size:
# warnings.warn(f'One of the splits is larger than the chunk size. '
# f'Consider increasing the chunk size or splitting the text differently.')

if chunk_len + len_fs > chunk_size:
yield separator.join(curr_chunk)
curr_chunk, chunk_len = [fs], len_fs
chunk = ''.join(curr_chunk[:-1])
if chunk: yield chunk # noqa E701
curr_chunk, chunk_len = [fs, sep], len_fs + sep_len
else:
curr_chunk.append(fs)
chunk_len += len_fs + separator_len
curr_chunk.extend((fs, sep))
chunk_len += len_fs + sep_len

if curr_chunk:
yield separator.join(curr_chunk)
chunk = ''.join(curr_chunk[:-1])
if chunk: yield chunk # noqa E701
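
The zip_longest comment above is the heart of the new logic; spelled out with just the standard library, re.split on a capturing group keeps every separator match, and the even/odd slices pair each piece with the separator that followed it:

import re
from itertools import zip_longest

raw = re.split(r'(\s)', 'a b\tc')  # capturing group keeps the separator matches
# raw == ['a', ' ', 'b', '\t', 'c']
pairs = list(zip_longest(raw[0::2], raw[1::2], fillvalue=''))
# pairs == [('a', ' '), ('b', '\t'), ('c', '')]  (the final piece gets an empty separator)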


def text_split_fuzzy(text: str,
chunk_size: int,
chunk_overlap: int=0,
separator: str='\n\n',
joiner=None,
len_func=len
):
'''
@@ -97,7 +118,11 @@ def text_split_fuzzy(text: str,
chunk_overlap (int, optional): Number of characters to overlap at the edges of chunks
seperator (str, optional): String that already splits "text" into sections
separator (str, optional): Regex used to split `text` into sections. Do not include outer capturing parentheses.
Don't forget to use escaping where necessary.
joiner (str, optional): Exact string used to rejoin any sections in order to meet the target length.
Defaults to using the literal match from the separator.
len_func (callable, optional): Function to measure chunk length, len() by default
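
A hedged usage sketch for the updated text_split_fuzzy() signature, using only parameters shown in this hunk (sample text and sizes are arbitrary):

from ogbujipt.text_helper import text_split_fuzzy  # path inferred from pylib/text_helper.py

doc = 'Para one.\n\nPara two is a little longer.\n\nPara three.'
chunks = list(text_split_fuzzy(doc, chunk_size=40, chunk_overlap=10, separator='\n\n', joiner='\n\n'))
# chunk_size guides the maximum length per chunk (len_func=len by default), chunk_overlap
# allows roughly 10 characters of overlap at chunk edges, and joiner='\n\n' rejoins sections
# with a literal blank line rather than with whatever the regex separator matched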
4 changes: 3 additions & 1 deletion test/test_text_splitter.py
@@ -65,7 +65,9 @@ def test_zero_overlap(LOREM_IPSUM):
assert chunks[0] == 'Lorem ipsum dolor sit amet, consectetur adipiscing elit. Fusce vestibulum nisi eget mauris'
assert chunks[2] == 'massa, in varius nulla ex vel ipsum. Nullam vitae eros nec ante sagittis luctus. Nullam scelerisque'
assert chunks[3] == 'dolor eu orci iaculis, at convallis nulla luctus. Praesent eget ex id arcu facilisis varius vel id'
assert chunks[-1] == 'elit.'
assert chunks[-1] == 'quam justo at elit.'
for chunk in chunks:
assert len(chunk) <= 100


if __name__ == '__main__':
