Implemented llama.cpp-style API keys. Better HTTP error handling
uogbuji committed Mar 4, 2024
1 parent 54f102e commit c8f240c
Showing 1 changed file with 22 additions and 9 deletions.
pylib/llm_wrapper.py
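In practical terms, the new apikey support lets a key be set once on the wrapper instance, or passed per call, and it is sent as a Bearer token in the Authorization header, matching the API-key scheme used by llama.cpp's HTTP server. A minimal usage sketch, assuming the package import path and using placeholder endpoint and key values (the prompt, min_p and resp['content'] follow the class's own docstring example):

import asyncio
from ogbujipt.llm_wrapper import llama_cpp_http  # import path assumed from pylib/llm_wrapper.py

# Placeholder endpoint and key; a per-call apikey=... argument overrides the instance key
llm_api = llama_cpp_http(base_url='http://localhost:8080', apikey='example-key')

resp = asyncio.run(llm_api('Knock knock!', min_p=0.05))
print(resp['content'])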
@@ -35,6 +35,8 @@
 except ImportError:
     httpx = None
 
+HTTP_SUCCESS = 200
+
 # In many cases of self-hosted models you just get whatever model is loaded, rather than specifying it in the API
 DUMMY_MODEL = 'DUMMY_MODEL'
 
@@ -274,7 +276,7 @@ class llama_cpp_http(llm_wrapper):
     >>> resp = asyncio.run(llm_api('Knock knock!', min_p=0.05))
     >>> resp['content']
     '''
-    def __init__(self, base_url, model=None, **kwargs):
+    def __init__(self, base_url, apikey=None, model=None, **kwargs):
         '''
         Args:
             model (str, optional): Name of model to select from the options available on the endpoint
@@ -287,9 +289,10 @@ def __init__(self, base_url, model=None, **kwargs):
             raise ImportError('httpx module not available; Perhaps try: `pip install httpx`')
         self.model = model
         self.base_url = base_url
+        self.apikey = apikey
         super().__init__(model, **kwargs)
 
-    async def __call__(self, prompt, req='/completion', timeout=30.0, **kwargs):
+    async def __call__(self, prompt, req='/completion', timeout=30.0, apikey=None, **kwargs):
         '''
         Invoke the LLM with a completion request
@@ -305,14 +308,18 @@ async def __call__(self, prompt, req='/completion', timeout=30.0, **kwargs):
             dict: JSON response from the LLM
         '''
         header = {'Content-Type': 'application/json'}
+        if apikey is None:
+            apikey = self.apikey
+        if apikey:
+            header['Authorization'] = f'Bearer {apikey}'
         async with httpx.AsyncClient() as client:
             # FIXME: Decide the best way to return result metadata
             result = await client.post(f'{self.base_url}{req}', json={'prompt': prompt, **kwargs},
                                        headers=header, timeout=timeout)
-            try:
-                return result.json()
-            except ValueError:
-                return result.text
+            if result.status_code == HTTP_SUCCESS:
+                return result.json()
+            else:
+                raise RuntimeError(f'Unexpected response from {self.base_url}{req}:\n{repr(result)}')

     hosted_model = openai_api.hosted_model  # Borrow method
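
Viewed outside the class, the request pattern introduced above (optional Bearer header plus an explicit status check) amounts to the following self-contained sketch; the function name post_completion and its parameters are illustrative only, not part of the library:

import httpx

HTTP_SUCCESS = 200

async def post_completion(base_url, prompt, apikey=None, timeout=30.0, **kwargs):
    # Only send an Authorization header when a key is actually configured
    header = {'Content-Type': 'application/json'}
    if apikey:
        header['Authorization'] = f'Bearer {apikey}'
    async with httpx.AsyncClient() as client:
        result = await client.post(f'{base_url}/completion', json={'prompt': prompt, **kwargs},
                                   headers=header, timeout=timeout)
    # Non-200 responses are surfaced as errors rather than silently returned as text
    if result.status_code == HTTP_SUCCESS:
        return result.json()
    raise RuntimeError(f'Unexpected response from {base_url}/completion:\n{repr(result)}')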

@@ -346,7 +353,7 @@ class llama_cpp_http_chat(llama_cpp_http):
     >>> resp = asyncio.run(llm_api(prompt_to_chat('Knock knock!')))
     >>> llm_api.first_choice_message(resp)
     '''
-    async def __call__(self, prompt, req='/v1/chat/completions', timeout=30.0, **kwargs):
+    async def __call__(self, prompt, req='/v1/chat/completions', timeout=30.0, apikey=None, **kwargs):
         '''
         Invoke LLM with a completion request
@@ -357,17 +364,23 @@ async def __call__(self, prompt, req='/v1/chat/completions', timeout=30.0, **kwargs):
             See Completions.create in OpenAI API, but in short, these:
                 best_of, echo, frequency_penalty, logit_bias, logprobs, max_tokens, n
                 presence_penalty, seed, stop, stream, suffix, temperature, top_p, user
 
         Returns:
             dict: JSON response from the LLM
         '''
         header = {'Content-Type': 'application/json'}
+        if apikey is None:
+            apikey = self.apikey
+        if apikey:
+            header['Authorization'] = f'Bearer {apikey}'
         async with httpx.AsyncClient() as client:
             # FIXME: Decide the best way to return result metadata
             result = await client.post(f'{self.base_url}{req}', json={'messages': prompt, **kwargs},
                                        headers=header, timeout=timeout)
-            # print(result.text)
-            return result.json()
+            if result.status_code == HTTP_SUCCESS:
+                return result.json()
+            else:
+                raise RuntimeError(f'Unexpected response from {self.base_url}{req}:\n{repr(result)}')
 
     @staticmethod
     def first_choice_message(response):
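
For the chat variant the call pattern follows the class docstring's example, now with an optional key; the import paths are assumed from the repository layout, and the endpoint URL and key are placeholders:

import asyncio
# prompt_to_chat is referenced in the docstring above; its import location is an assumption here
from ogbujipt.llm_wrapper import llama_cpp_http_chat, prompt_to_chat

chat_api = llama_cpp_http_chat(base_url='http://localhost:8080', apikey='example-key')
resp = asyncio.run(chat_api(prompt_to_chat('Knock knock!')))
print(chat_api.first_choice_message(resp))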
