Skip to content

Commit

Permalink
[#72] Bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
uogbuji committed Mar 5, 2024
1 parent 40e284f commit ef6ed4c
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions pylib/llm_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,16 @@ async def __call__(self, prompt, req='/completion', timeout=30.0, apikey=None, *
else:
raise RuntimeError(f'Unexpected response from {self.base_url}{req}:\n{repr(result)}')

hosted_model = openai_api.hosted_model # Borrow method
def hosted_model(self) -> str:
'''
Model introspection: Query the API to find what model is being run for LLM calls
>>> from ogbujipt.llm_wrapper import llama_cpp_http
>>> llm_api = llama_cpp_http(base_url='http://localhost:8000')
>>> print(llm_api.hosted_model())
'/models/TheBloke_WizardLM-13B-V1.0-Uncensored-GGML/wizardlm-13b-v1.0-uncensored.ggmlv3.q6_K.bin'
'''
return self.available_models()[0]

def available_models(self) -> List[str]:
'''
Expand Down Expand Up @@ -353,7 +362,7 @@ class llama_cpp_http_chat(llama_cpp_http):
>>> resp = asyncio.run(llm_api(prompt_to_chat('Knock knock!')))
>>> llm_api.first_choice_message(resp)
'''
async def __call__(self, prompt, req='/v1/chat/completions', timeout=30.0, apikey=None, **kwargs):
async def __call__(self, messages, req='/v1/chat/completions', timeout=30.0, apikey=None, **kwargs):
'''
Invoke LLM with a completion request
Expand All @@ -369,13 +378,12 @@ async def __call__(self, prompt, req='/v1/chat/completions', timeout=30.0, apike
dict: JSON response from the LLM
'''
header = {'Content-Type': 'application/json'}
if apikey is None:
apikey = self.apikey
apikey = apikey or self.apikey
if apikey:
header['Authorization'] = f'Bearer {apikey}'
async with httpx.AsyncClient() as client:
# FIXME: Decide the best way to return result metadata
result = await client.post(f'{self.base_url}{req}', json={'messages': prompt, **kwargs},
result = await client.post(f'{self.base_url}{req}', json={'messages': messages, **kwargs},
headers=header, timeout=timeout)
if result.status_code == HTTP_SUCCESS:
return result.json()
Expand Down

0 comments on commit ef6ed4c

Please sign in to comment.