
Commit

Merge branch 'main' of github.com:uogbuji/OgbujiPT
choccccy committed Jun 25, 2024
2 parents 62ca64e + b2ed163 commit 3c9608f
Showing 9 changed files with 87 additions and 74 deletions.
21 changes: 21 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,27 @@ Notable changes to Format based on [Keep a Changelog](https://keepachangelog.co
-->

## [0.9.2] - 20240625

### Added

- `joiner` param to `text_helper.text_split()` for better control of regex separator handling
- query filter mix-in, `embedding.pgvector.match_oneof()`, for use with `meta_filter` argument to `DB.search`
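
A hedged usage sketch of the new query filter mix-in (argument names for `match_oneof()` and the `db` setup are assumptions, not the definitive API):

```python
from ogbujipt.embedding.pgvector import match_oneof

async def poi_search(db):
    # db: an already-initialized pgvector data DB helper (setup elided)
    # match_oneof() yields a callable, which is what DB.search expects for meta_filter items
    return await db.search(
        text='Tell me about local points of interest',
        meta_filter=[match_oneof('category', ('park', 'museum'))],
        limit=4)
```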

### Changed

- Index word loom items by their literal default language text, as well
- Cleaned up PGVector query-building logic

### Fixed

- `llm_wrapper.llm_response` objects to handle tool calls
- failure of some matching scenarios in `embedding.pgvector.match_exact()`

### Removed

- Previously deprecated `first_choice_text` & `first_choice_message` methods

## [0.9.1] - 20240604

### Fixed
6 changes: 5 additions & 1 deletion README.md
@@ -156,6 +156,10 @@ pytest test

If you want to make contributions to the project, please [read these notes](https://github.com/OoriData/OgbujiPT/wiki/Notes-for-contributors).

# Resources

* [Against mixing environment setup with code](https://huggingface.co/blog/ucheog/separate-env-setup-from-code)

# License

Apache 2. For tha culture!
@@ -189,7 +193,7 @@ I mentioned the bias to software engineering, but what does this mean?

## Does this support GPU for locally-hosted models

Yes, but you have to make sure you set up your back end LLm server (llama.cpp or text-generation-webui) with GPU, and properly configure the model you load into it. If you can use the webui to query your model and get GPU usage, that will also apply here in OgbujiPT.
Yes, but you have to make sure you set up your back end LLM server (llama.cpp or text-generation-webui) with GPU, and properly configure the model you load into it.

Many install guides I've found for Mac, Linux and Windows touch on enabling GPU, but the ecosystem is still in its early days, and helpful resources can feel scattered.

7 changes: 4 additions & 3 deletions pylib/embedding/pgvector.py
@@ -93,9 +93,6 @@ def __init__(self, embedding_model, table_name: str, pool):
else:
raise ValueError('embedding_model must be a SentenceTransformer object or None')

self.table_name = table_name
self.pool = pool

@classmethod
async def from_conn_params(cls, embedding_model, table_name, host, port, db_name, user, password) -> 'PGVectorHelper': # noqa: E501
'''
@@ -117,6 +114,10 @@ async def from_conn_params(cls, embedding_model, table_name, host, port, db_name
async def init_pool(conn):
'''
Initialize vector extension for a connection from a pool
Can be invoked from upstream if they're managing the connection pool themselves
If they choose to have us create a connection pool (e.g. from_conn_params), it will use this
'''
await conn.execute('CREATE EXTENSION IF NOT EXISTS vector;')
await register_vector(conn)
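
A minimal sketch of the upstream-managed-pool case described in the docstring above, assuming asyncpg is used directly and that `init_pool` is reachable via `PGVectorHelper` as shown (connection parameters are placeholders):

```python
import asyncpg
from ogbujipt.embedding.pgvector import PGVectorHelper

async def make_pool():
    # Passing init ensures every pooled connection has the vector extension
    # created and registered before the helper class uses it
    return await asyncpg.create_pool(
        host='localhost', port=5432, database='demo_db',
        user='demo_user', password='demo_password',
        init=PGVectorHelper.init_pool)
```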
22 changes: 11 additions & 11 deletions pylib/embedding/pgvector_data.py
@@ -182,22 +182,17 @@ async def search(
query_embedding = list(self._embedding_model.encode(text))

# Build where clauses
if threshold is None:
# No where clauses, so don't bother with the WHERE keyword
where_clauses = []
query_args = [query_embedding]
else: # construct where clauses
where_clauses = []
query_args = [query_embedding]
if threshold is not None:
query_args.append(threshold)
where_clauses.append(THRESHOLD_WHERE_CLAUSE.format(query_threshold=f'${len(query_args)+1}'))
query_args = [query_embedding]
where_clauses = []
if threshold is not None:
query_args.append(threshold)
where_clauses.append(THRESHOLD_WHERE_CLAUSE.format(query_threshold=f'${len(query_args)}'))

for mf in meta_filter:
assert callable(mf), 'All meta_filter items must be callable'
clause, pval = mf()
where_clauses.append(clause.format(len(query_args)+1))
query_args.append(pval)
where_clauses.append(clause.format(len(query_args)))

where_clauses_str = 'WHERE\n' + 'AND\n'.join(where_clauses) if where_clauses else ''

@@ -206,6 +201,11 @@
else:
limit_clause = ''

# print(QUERY_DATA_TABLE.format(table_name=self.table_name, where_clauses=where_clauses_str,
# limit_clause=limit_clause,
# ))
# print(query_args)

# Execute the search via SQL
async with self.pool.acquire() as conn:
# Uncomment to debug
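An illustration of the append-then-number convention the reworked clause builder relies on; the clause strings below are simplified stand-ins, not the module's actual THRESHOLD_WHERE_CLAUSE or metadata clause templates:

```python
# $1 is always the query embedding; each later argument is numbered by its
# position in query_args right after it is appended
query_args = [[0.1, 0.2, 0.3]]   # $1: query embedding
query_args.append(0.75)          # $2: similarity threshold
where_clauses = [f'(1 - (embedding <=> $1)) >= ${len(query_args)}\n']
query_args.append('park')        # $3: value produced by a meta_filter callable
where_clauses.append(f"(metadata ->> 'category') = ${len(query_args)}\n")
print('WHERE\n' + 'AND\n'.join(where_clauses))
# WHERE
# (1 - (embedding <=> $1)) >= $2
# AND
# (metadata ->> 'category') = $3
```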
53 changes: 9 additions & 44 deletions pylib/llm_wrapper.py
@@ -12,11 +12,11 @@
'''

import os
import json
import asyncio
import concurrent.futures
from functools import partial
from typing import List
import warnings

from amara3 import iri

@@ -70,8 +70,13 @@ def from_openai_chat(response):
if 'message' in c:
c['message'] = llm_response(c['message'])
rc1 = resp['choices'][0]
# print(f'from_openai_chat: {rc1 =}')
resp['first_choice_text'] = rc1['text'] if 'text' in rc1 else rc1['message']['content']
# No response message content if a tool call is invoked
if rc1.get('message', {}).get('tool_calls'):
# WTH does OpenAI have these arguments properties as plain text? Seems a massive layering violation
for tc in rc1['message']['tool_calls']:
tc['function']['arguments_obj'] = json.loads(tc['function']['arguments'])
else:
resp['first_choice_text'] = rc1['text'] if 'text' in rc1 else rc1['message']['content']
else:
resp['first_choice_text'] = resp['content']
return resp
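
A hedged consumer-side sketch of the tool-call handling added above; the `dispatch` helper and the shape of `raw_response` are illustrative assumptions:

```python
from ogbujipt.llm_wrapper import llm_response

def dispatch(raw_response):
    # raw_response: a chat completion payload from an OpenAI-compatible API
    resp = llm_response.from_openai_chat(raw_response)
    tool_calls = resp['choices'][0].get('message', {}).get('tool_calls')
    if tool_calls:
        # arguments_obj is the JSON-decoded twin of the plain-text arguments field
        return [(tc['function']['name'], tc['function']['arguments_obj'])
                for tc in tool_calls]
    # No tool call was invoked, so first_choice_text was populated
    return resp.first_choice_text
```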
@@ -190,8 +195,7 @@ async def __call__(self, prompt, api_func=None, **kwargs):
kwargs (dict, optional): Extra parameters to pass to the model via API.
See Completions.create in OpenAI API, but in short, these:
best_of, echo, frequency_penalty, logit_bias, logprobs, max_tokens, n
presence_penalty, seed, stop, stream, suffix, temperature, top_p, user
q
presence_penalty, seed, stop, stream, suffix, temperature, top_p, userq
Returns:
dict: JSON response from the LLM
'''
@@ -245,19 +249,6 @@ def available_models(self) -> List[str]:
raise RuntimeError(f'Unexpected response from {self.base_url}/models:\n{repr(resp)}')
return [ i['id'] for i in resp['data'] ]

@staticmethod
def first_choice_text(response):
'''
Given an OpenAI-compatible API simple completion response, return the first choice text
'''
warnings.warn('The first_choice_text method is deprecated; use the first_choice_text attribute or key instead', DeprecationWarning, stacklevel=2) # noqa E501
try:
return response.choices[0].text
except AttributeError:
raise RuntimeError(
f'''Response does not appear to be an OpenAI API completion structure, as expected:
{repr(response)}''')


class openai_chat_api(openai_api):
'''
@@ -323,19 +314,6 @@ async def __call__(self, prompt, api_func=None, **kwargs):
# Haven't implemented any OpenAI API calls that are async, so just call the sync version
return self.call(prompt, api_func, **kwargs)

@staticmethod
def first_choice_message(response):
'''
Given an OpenAI-compatible API chat completion response, return the first choice message content
'''
warnings.warn('The first_choice_message method is deprecated; use the first_choice_text attribute or key instead', DeprecationWarning, stacklevel=2) # noqa E501
try:
return response.choices[0].message.content
except AttributeError:
raise RuntimeError(
f'''Response does not appear to be an OpenAI API chat-style completion structure, as expected:
{repr(response)}''')


class llama_cpp_http(llm_wrapper):
'''
@@ -466,19 +444,6 @@ async def __call__(self, messages, req='/v1/chat/completions', timeout=30.0, api
else:
raise RuntimeError(f'Unexpected response from {self.base_url}{req}:\n{repr(result)}')

@staticmethod
def first_choice_message(response):
'''
Given an OpenAI-compatible API chat completion response, return the first choice message content
'''
warnings.warn('The first_choice_message method is deprecated; use the first_choice_text attribute or key instead', DeprecationWarning, stacklevel=2) # noqa E501
try:
return response['choices'][0]['message']['content']
except (IndexError, KeyError):
raise RuntimeError(
f'''Response does not appear to be a llama.cpp API chat-style completion structure, as expected:
{repr(response)}''')


class ctransformer:
'''
18 changes: 15 additions & 3 deletions pylib/word_loom.py
@@ -86,7 +86,9 @@ def clone(self, value=None, deflang=None, altlang=None, meta=None, markers=None)
def load(fp_or_str, lang='en', preserve_key=False):
'''
Read a word loom and return the tables as top-level result mapping
Loads the TOML, then selects text by given language
Loads the TOML
Return a dict of the language items, indexed by the TOML key as well as its default language text
fp_or_str - file-like object or string containing TOML
lang - select only texts in this language (default: 'en')
@@ -97,9 +99,12 @@ def load(fp_or_str, lang='en', preserve_key=False):
>>> loom = word_loom.load(fp)
>>> loom['test_prompt_joke'].meta
{'tag': 'humor', 'updated': '2024-01-01'}
>>> str(loom['test_prompt_joke'])
>>> actual_text = loom['test_prompt_joke']
>>> str(actual_text)
'Tell me a funny joke about {topic}\n'
>>> str(loom[str(actual_text)])
'Tell me a funny joke about {topic}\n'
>>> loom['test_prompt_joke'].in_lang('fr')
>>> loom[str(actual_text)].in_lang('fr')
'Dites-moi une blague drôle sur {topic}\n'
'''
# Ensure we have a file-like object
@@ -135,5 +140,12 @@
meta = {kk: vv for kk, vv in v.items() if (not kk.startswith('_') and kk not in ('text', 'markers'))}
if preserve_key:
meta['_key'] = k
if k in texts:
warnings.warn(f'Key {k} duplicates an existing item, which will be overwritten')
texts[k] = T(text, lang, altlang=altlang, meta=meta, markers=markers)
# Also index by literal text
if text in texts:
warnings.warn(
f'Item default language text {text[:20]} duplicates an existing item, which will be overwritten')
texts[text] = T(text, lang, altlang=altlang, meta=meta, markers=markers)
return texts
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -21,12 +21,12 @@ classifiers = [
"Programming Language :: Python",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = [
"openai>=1.1.0",
"python-dotenv",
"tomli",
"amara3.iri"
]
@@ -78,7 +78,7 @@ cov = [
]

[[tool.hatch.envs.all.matrix]]
python = ["3.10", "3.11"]
python = ["3.10", "3.11", "3.12"]

[tool.hatch.envs.lint]
detached = true
6 changes: 3 additions & 3 deletions test/test_ogbujipt.py
@@ -12,14 +12,14 @@

# import pytest

from ogbujipt.llm_wrapper import openai_chat_api
from ogbujipt.llm_wrapper import llm_response #, openai_chat_api

def test_oapi_first_choice_text(OPENAI_TEXT_RESPONSE_OBJECT):
text1 = openai_chat_api.first_choice_text(OPENAI_TEXT_RESPONSE_OBJECT)
text1 = llm_response.from_openai_chat(OPENAI_TEXT_RESPONSE_OBJECT).first_choice_text
assert text1 == '…is an exceptional employee who has made significant contributions to our company.'

def test_oapi_first_choice_message(OPENAI_MSG_RESPONSE_OBJECT):
msg1 = openai_chat_api.first_choice_message(OPENAI_MSG_RESPONSE_OBJECT)
msg1 = llm_response.from_openai_chat(OPENAI_MSG_RESPONSE_OBJECT).first_choice_text
assert msg1 == '…is an exceptional employee who has made significant contributions to our company.'


24 changes: 17 additions & 7 deletions test/test_word_loom.py
@@ -12,7 +12,7 @@

import os
import sys
# import pkgutil
# from itertools import chain

import pytest

@@ -49,19 +49,29 @@ def test_load_fp_vs_str(SAMPLE_TOML_STR, SAMPLE_TOML_FP):
def test_sample_texts_check(SAMPLE_TOML_STR):
# print(SAMPLE_TOML)
loom = word_loom.load(SAMPLE_TOML_STR)
assert list(sorted(loom.keys())) == ['davinci3_instruct_system', 'hello_translated', 'i18n_context', 'write_i18n_advocacy']
assert list(sorted([v[:20] for v in loom.values()])) == ['Hello', 'Internationalization', 'Obey the instruction', '{davinci3_instruct_s']
assert [v.markers or [] for v in loom.values()] == [[], [], ['davinci3_instruct_system', 'i18n_context'], []]
# default language text is also a key
assert len(loom.keys()) == 8
for k in ['davinci3_instruct_system', 'hello_translated', 'i18n_context', 'write_i18n_advocacy']:
assert k in loom.keys()
assert 'Hello' in loom.keys()

# loom_dlt = set([v[:20] for v in loom.values()])

assert len(set(loom.values())) == 4
for k in ['Hello', 'Internationalization', 'Obey the instruction', '{davinci3_instruct_s']:
assert k in [v[:20] for v in loom.values()]

assert [v.markers or [] for v in loom.values()] == [[], [], [], [], ['davinci3_instruct_system', 'i18n_context'], ['davinci3_instruct_system', 'i18n_context'], [], []]
assert loom['davinci3_instruct_system'].lang == 'en'

# Default language is English
loom1 = word_loom.load(SAMPLE_TOML_STR, lang='en')
assert loom1 == loom

loom = word_loom.load(SAMPLE_TOML_STR, lang='fr')
assert list(sorted(loom.keys())) == ['goodbye_translated', 'hardcoded_food', 'translate_request']
assert list(sorted([v[:20] for v in loom.values()])) == ['Adieu', 'Comment dit-on en an', 'pomme de terre']
assert [v.markers or [] for v in loom.values()] == [['hardcoded_food'], [], []]
assert list(sorted(loom.keys())) == ['Adieu', 'Comment dit-on en anglais: {hardcoded_food}?', 'goodbye_translated', 'hardcoded_food', 'pomme de terre', 'translate_request']
assert list(sorted(set([v[:20] for v in loom.values()]))) == ['Adieu', 'Comment dit-on en an', 'pomme de terre']
assert [v.markers or [] for v in loom.values()] == [['hardcoded_food'], ['hardcoded_food'], [], [], [], []]
assert loom['hardcoded_food'].lang == 'fr'


