Skip to content

Commit

Permalink
Together AI issues (#185)
Browse files Browse the repository at this point in the history
* Ensure all configurations work.
  • Loading branch information
vivekuppal committed Mar 22, 2024
1 parent 91edbbb commit 6b96e32
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 19 deletions.
20 changes: 12 additions & 8 deletions app/transcribe/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,18 +85,22 @@ def handle_args_batch_tasks(args: argparse.Namespace, global_vars: Transcription
if args.validate_api_key is not None:
chat_inference_provider = config['General']['chat_inference_provider']
if chat_inference_provider == 'openai':
base_url = config['OpenAI']['base_url']
settings_section = 'OpenAI'
elif chat_inference_provider == 'together':
base_url = config['Together']['base_url']
settings_section = 'Together'

if utilities.is_api_key_valid(api_key=args.validate_api_key, base_url=base_url):
base_url = config[settings_section]['base_url']
model = config[settings_section]['ai_model']

if utilities.is_api_key_valid(api_key=args.validate_api_key, base_url=base_url, model=model):
print('The api_key is valid')
base_url = config['OpenAI']['base_url']
client = openai.OpenAI(api_key=args.validate_api_key, base_url=base_url)
models = utilities.get_available_models(client=client)
print('Available models: ')
for model in models:
print(f' {model}')

if base_url != 'https://api.together.xyz':
models = utilities.get_available_models(client=client)
print('Available models: ')
for model in models:
print(f' {model}')
client.close()
else:
print('The api_key is not valid')
Expand Down
42 changes: 36 additions & 6 deletions app/transcribe/gpt_responder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import conversation
import constants
from tsutils import app_logging as al
from tsutils import duration
from tsutils import duration, utilities


root_logger = al.get_logger()
Expand Down Expand Up @@ -47,8 +47,17 @@ def summarize(self) -> str:
"""
root_logger.info(GPTResponder.summarize.__name__)

if self.config['OpenAI']['api_key'] in ('', 'API_KEY'):
# Cannot summarize without connection to LLM
chat_inference_provider = self.config['General']['chat_inference_provider']
if chat_inference_provider == 'openai':
settings_section = 'OpenAI'
elif chat_inference_provider == 'together':
settings_section = 'Together'

api_key = self.config[settings_section]['api_key']
base_url = self.config[settings_section]['base_url']
model = self.config[settings_section]['ai_model']

if not utilities.is_api_key_valid(api_key=api_key, base_url=base_url, model=model):
return None

with duration.Duration(name='OpenAI Summarize', screen=False):
Expand Down Expand Up @@ -81,7 +90,18 @@ def generate_response_from_transcript_no_check(self) -> str:
"""
try:
root_logger.info(GPTResponder.generate_response_from_transcript_no_check.__name__)
if self.config['OpenAI']['api_key'] in ('', 'API_KEY'):
chat_inference_provider = self.config['General']['chat_inference_provider']

if chat_inference_provider == 'openai':
settings_section = 'OpenAI'
elif chat_inference_provider == 'together':
settings_section = 'Together'

api_key = self.config[settings_section]['api_key']
base_url = self.config[settings_section]['base_url']
model = self.config[settings_section]['ai_model']

if not utilities.is_api_key_valid(api_key=api_key, base_url=base_url, model=model):
return None

with duration.Duration(name='OpenAI Chat Completion', screen=False):
Expand Down Expand Up @@ -172,7 +192,17 @@ def generate_response_for_selected_text(self, text: str):
"""
try:
root_logger.info(GPTResponder.generate_response_for_selected_text.__name__)
if self.config['OpenAI']['api_key'] in ('', 'API_KEY'):
chat_inference_provider = self.config['General']['chat_inference_provider']
if chat_inference_provider == 'openai':
settings_section = 'OpenAI'
elif chat_inference_provider == 'together':
settings_section = 'Together'

api_key = self.config[settings_section]['api_key']
base_url = self.config[settings_section]['base_url']
model = self.config[settings_section]['ai_model']

if not utilities.is_api_key_valid(api_key=api_key, base_url=base_url, model=model):
return None

with duration.Duration(name='OpenAI Chat Completion Selected', screen=False):
Expand Down Expand Up @@ -305,7 +335,7 @@ def __init__(self,
self.llm_client = openai.OpenAI(api_key=api_key,
base_url=base_url)
self.model = self.config['Together']['ai_model']
print(f'[INFO] Using Together for inference. Model: {self.model}')
print(f'[INFO] Using Together AI for inference. Model: {self.model}')
super().__init__(config=self.config,
convo=convo,
save_to_file=save_to_file,
Expand Down
4 changes: 3 additions & 1 deletion app/transcribe/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,11 +487,13 @@ def create_ui_components(root, config: dict):
if chat_inference_provider == 'openai':
api_key = config['OpenAI']['api_key']
base_url = config['OpenAI']['base_url']
model = config['OpenAI']['ai_model']
elif chat_inference_provider == 'together':
api_key = config['Together']['api_key']
base_url = config['Together']['base_url']
model = config['Together']['ai_model']

if not utilities.is_api_key_valid(api_key=api_key, base_url=base_url):
if not utilities.is_api_key_valid(api_key=api_key, base_url=base_url, model=model):
# Disable buttons that interact with backend services
continuous_response_button.configure(state='disabled')
response_now_button.configure(state='disabled')
Expand Down
3 changes: 2 additions & 1 deletion examples/together/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
},
{
'role': 'user',
'content': 'Tell me about Fantasy Football',
'content': 'Are you online',
}
],
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
max_tokens=1024
)

print(chat_completion.choices[0].message.content)
client.close()
32 changes: 29 additions & 3 deletions tsutils/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
import openai


valid_api_key: bool = False


def merge(first: dict, second: dict, path=[]):
"""Recursively merge two dictionaries.
For keys with different values, values in the second dictionary
Expand Down Expand Up @@ -207,18 +210,41 @@ def get_available_models(client: openai.OpenAI) -> list:
return sorted(return_val)


def is_api_key_valid(api_key: str, base_url) -> bool:
def is_api_key_valid(api_key: str, base_url: str, model: str) -> bool:
"""Check if it is valid openai compatible openai key for the provider
"""
openai.api_key = api_key

global valid_api_key # pylint: disable=W0603

if valid_api_key:
return True

openai.api_key = api_key
client = openai.OpenAI(api_key=api_key, base_url=base_url)

try:
client.models.list()
# Ideally models list is the best way to determine if api key is valid
# Some of the OpenAI compatible vendors do not support all the methods though
# client.models.list()
chat_completion = client.chat.completions.create(
messages=[
{
'role': 'system',
'content': 'You are an AI assistant',
},
{
'role': 'user',
'content': 'Are you online',
}
],
model=model,
max_tokens=1024
)
assert len(chat_completion.choices[0].message.content) > 0
client.close()
except openai.AuthenticationError as e:
print(e)
return False

valid_api_key = True
return True

0 comments on commit 6b96e32

Please sign in to comment.