Together AI issues #185

Merged 2 commits on Mar 22, 2024
20 changes: 12 additions & 8 deletions app/transcribe/args.py
@@ -85,18 +85,22 @@ def handle_args_batch_tasks(args: argparse.Namespace, global_vars: Transcription
     if args.validate_api_key is not None:
         chat_inference_provider = config['General']['chat_inference_provider']
         if chat_inference_provider == 'openai':
-            base_url = config['OpenAI']['base_url']
+            settings_section = 'OpenAI'
         elif chat_inference_provider == 'together':
-            base_url = config['Together']['base_url']
+            settings_section = 'Together'
 
-        if utilities.is_api_key_valid(api_key=args.validate_api_key, base_url=base_url):
+        base_url = config[settings_section]['base_url']
+        model = config[settings_section]['ai_model']
+
+        if utilities.is_api_key_valid(api_key=args.validate_api_key, base_url=base_url, model=model):
             print('The api_key is valid')
-            base_url = config['OpenAI']['base_url']
             client = openai.OpenAI(api_key=args.validate_api_key, base_url=base_url)
-            models = utilities.get_available_models(client=client)
-            print('Available models: ')
-            for model in models:
-                print(f' {model}')
+
+            if base_url != 'https://api.together.xyz':
+                models = utilities.get_available_models(client=client)
+                print('Available models: ')
+                for model in models:
+                    print(f' {model}')
             client.close()
         else:
             print('The api_key is not valid')
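Note on the hunk above: the new branch assumes a config object with a General section naming the chat provider and per-provider sections (OpenAI, Together) that each carry api_key, base_url, and ai_model; if the provider is neither 'openai' nor 'together', settings_section is never assigned and the next lookup would raise a NameError. A minimal sketch of the lookup with a plain dict standing in for the real config (all values are illustrative placeholders, not project defaults):

# Sketch of the provider lookup introduced in handle_args_batch_tasks.
# The dict mapping is a compact stand-in for the if/elif in the diff;
# the config values are placeholders.
config = {
    'General': {'chat_inference_provider': 'together'},
    'OpenAI': {'base_url': 'https://api.openai.com/v1', 'ai_model': 'gpt-3.5-turbo'},
    'Together': {'base_url': 'https://api.together.xyz',
                 'ai_model': 'mistralai/Mixtral-8x7B-Instruct-v0.1'},
}

provider = config['General']['chat_inference_provider']
settings_section = {'openai': 'OpenAI', 'together': 'Together'}[provider]
base_url = config[settings_section]['base_url']
model = config[settings_section]['ai_model']
print(settings_section, base_url, model)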
42 changes: 36 additions & 6 deletions app/transcribe/gpt_responder.py
@@ -7,7 +7,7 @@
 import conversation
 import constants
 from tsutils import app_logging as al
-from tsutils import duration
+from tsutils import duration, utilities
 
 
 root_logger = al.get_logger()
@@ -47,8 +47,17 @@ def summarize(self) -> str:
         """
         root_logger.info(GPTResponder.summarize.__name__)
 
-        if self.config['OpenAI']['api_key'] in ('', 'API_KEY'):
-            # Cannot summarize without connection to LLM
+        chat_inference_provider = self.config['General']['chat_inference_provider']
+        if chat_inference_provider == 'openai':
+            settings_section = 'OpenAI'
+        elif chat_inference_provider == 'together':
+            settings_section = 'Together'
+
+        api_key = self.config[settings_section]['api_key']
+        base_url = self.config[settings_section]['base_url']
+        model = self.config[settings_section]['ai_model']
+
+        if not utilities.is_api_key_valid(api_key=api_key, base_url=base_url, model=model):
             return None
 
         with duration.Duration(name='OpenAI Summarize', screen=False):
@@ -81,7 +90,18 @@ def generate_response_from_transcript_no_check(self) -> str:
         """
         try:
             root_logger.info(GPTResponder.generate_response_from_transcript_no_check.__name__)
-            if self.config['OpenAI']['api_key'] in ('', 'API_KEY'):
+            chat_inference_provider = self.config['General']['chat_inference_provider']
+
+            if chat_inference_provider == 'openai':
+                settings_section = 'OpenAI'
+            elif chat_inference_provider == 'together':
+                settings_section = 'Together'
+
+            api_key = self.config[settings_section]['api_key']
+            base_url = self.config[settings_section]['base_url']
+            model = self.config[settings_section]['ai_model']
+
+            if not utilities.is_api_key_valid(api_key=api_key, base_url=base_url, model=model):
                 return None
 
             with duration.Duration(name='OpenAI Chat Completion', screen=False):
@@ -172,7 +192,17 @@ def generate_response_for_selected_text(self, text: str):
         """
         try:
             root_logger.info(GPTResponder.generate_response_for_selected_text.__name__)
-            if self.config['OpenAI']['api_key'] in ('', 'API_KEY'):
+            chat_inference_provider = self.config['General']['chat_inference_provider']
+            if chat_inference_provider == 'openai':
+                settings_section = 'OpenAI'
+            elif chat_inference_provider == 'together':
+                settings_section = 'Together'
+
+            api_key = self.config[settings_section]['api_key']
+            base_url = self.config[settings_section]['base_url']
+            model = self.config[settings_section]['ai_model']
+
+            if not utilities.is_api_key_valid(api_key=api_key, base_url=base_url, model=model):
                 return None
 
             with duration.Duration(name='OpenAI Chat Completion Selected', screen=False):
@@ -305,7 +335,7 @@ def __init__(self,
         self.llm_client = openai.OpenAI(api_key=api_key,
                                         base_url=base_url)
        self.model = self.config['Together']['ai_model']
-        print(f'[INFO] Using Together for inference. Model: {self.model}')
+        print(f'[INFO] Using Together AI for inference. Model: {self.model}')
         super().__init__(config=self.config,
                          convo=convo,
                          save_to_file=save_to_file,
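The provider-selection block above is now duplicated across summarize, generate_response_from_transcript_no_check, and generate_response_for_selected_text. A sketch of a hypothetical helper that could centralize it; the function name and error handling are illustrative and not part of this PR:

def get_chat_provider_settings(config: dict) -> tuple:
    """Return (api_key, base_url, model) for the configured chat provider.

    Hypothetical helper mirroring the if/elif block repeated in gpt_responder.py.
    """
    provider = config['General']['chat_inference_provider']
    if provider == 'openai':
        settings_section = 'OpenAI'
    elif provider == 'together':
        settings_section = 'Together'
    else:
        raise ValueError(f'Unknown chat_inference_provider: {provider}')
    section = config[settings_section]
    return section['api_key'], section['base_url'], section['ai_model']

Each method could then keep only the is_api_key_valid guard after a single call to this helper.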
4 changes: 3 additions & 1 deletion app/transcribe/ui.py
@@ -487,11 +487,13 @@ def create_ui_components(root, config: dict):
     if chat_inference_provider == 'openai':
         api_key = config['OpenAI']['api_key']
         base_url = config['OpenAI']['base_url']
+        model = config['OpenAI']['ai_model']
     elif chat_inference_provider == 'together':
         api_key = config['Together']['api_key']
         base_url = config['Together']['base_url']
+        model = config['Together']['ai_model']
 
-    if not utilities.is_api_key_valid(api_key=api_key, base_url=base_url):
+    if not utilities.is_api_key_valid(api_key=api_key, base_url=base_url, model=model):
         # Disable buttons that interact with backend services
         continuous_response_button.configure(state='disabled')
         response_now_button.configure(state='disabled')
3 changes: 2 additions & 1 deletion examples/together/stt.py
@@ -15,11 +15,12 @@
         },
         {
             'role': 'user',
-            'content': 'Tell me about Fantasy Football',
+            'content': 'Are you online',
         }
     ],
     model='mistralai/Mixtral-8x7B-Instruct-v0.1',
     max_tokens=1024
 )
 
 print(chat_completion.choices[0].message.content)
+client.close()
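The hunk starts partway through the client.chat.completions.create call, so the client setup is not visible here. A self-contained sketch of how an OpenAI-compatible client pointed at Together AI could drive this example; the environment variable name is an assumption, and the base URL is the value this PR checks against in args.py:

import os

import openai

# Assumed: the Together AI key is read from an environment variable (name is illustrative).
client = openai.OpenAI(api_key=os.environ['TOGETHER_API_KEY'],
                       base_url='https://api.together.xyz')

chat_completion = client.chat.completions.create(
    messages=[
        {'role': 'system', 'content': 'You are an AI assistant'},
        {'role': 'user', 'content': 'Are you online'},
    ],
    model='mistralai/Mixtral-8x7B-Instruct-v0.1',
    max_tokens=1024
)

print(chat_completion.choices[0].message.content)
client.close()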
32 changes: 29 additions & 3 deletions tsutils/utilities.py
@@ -6,6 +6,9 @@
 import openai
 
 
+valid_api_key: bool = False
+
+
 def merge(first: dict, second: dict, path=[]):
     """Recursively merge two dictionaries.
     For keys with different values, values in the second dictionary
@@ -207,18 +210,41 @@ def get_available_models(client: openai.OpenAI) -> list:
     return sorted(return_val)
 
 
-def is_api_key_valid(api_key: str, base_url) -> bool:
+def is_api_key_valid(api_key: str, base_url: str, model: str) -> bool:
     """Check if it is valid openai compatible openai key for the provider
     """
-    openai.api_key = api_key
 
+    global valid_api_key  # pylint: disable=W0603
+
+    if valid_api_key:
+        return True
+
+    openai.api_key = api_key
     client = openai.OpenAI(api_key=api_key, base_url=base_url)
 
     try:
-        client.models.list()
+        # Ideally models list is the best way to determine if api key is valid
+        # Some of the OpenAI compatible vendors do not support all the methods though
+        # client.models.list()
+        chat_completion = client.chat.completions.create(
+            messages=[
+                {
+                    'role': 'system',
+                    'content': 'You are an AI assistant',
+                },
+                {
+                    'role': 'user',
+                    'content': 'Are you online',
+                }
+            ],
+            model=model,
+            max_tokens=1024
+        )
+        assert len(chat_completion.choices[0].message.content) > 0
         client.close()
     except openai.AuthenticationError as e:
         print(e)
         return False
 
+    valid_api_key = True
     return True
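Usage sketch for the reworked validator: the module-level valid_api_key flag memoizes the first successful check, so only the first call issues the one-message chat completion; later calls in the same process return True immediately, even with different arguments. Only openai.AuthenticationError is caught, so connection errors (or the assert on an empty reply) propagate to the caller. The credential and model values below are placeholders:

from tsutils import utilities

# First call performs a small chat completion against the configured provider.
ok = utilities.is_api_key_valid(api_key='placeholder-key',
                                base_url='https://api.together.xyz',
                                model='mistralai/Mixtral-8x7B-Instruct-v0.1')
print(ok)

# Subsequent calls short-circuit on the cached module-level flag and skip the network.
print(utilities.is_api_key_valid(api_key='placeholder-key',
                                 base_url='https://api.together.xyz',
                                 model='mistralai/Mixtral-8x7B-Instruct-v0.1'))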