Skip to content

Commit

Permalink
bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
vivekuppal committed Mar 21, 2024
1 parent 91edbbb commit 8128e17
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 7 deletions.
1 change: 0 additions & 1 deletion app/transcribe/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ def handle_args_batch_tasks(args: argparse.Namespace, global_vars: Transcription

if utilities.is_api_key_valid(api_key=args.validate_api_key, base_url=base_url):
print('The api_key is valid')
base_url = config['OpenAI']['base_url']
client = openai.OpenAI(api_key=args.validate_api_key, base_url=base_url)
models = utilities.get_available_models(client=client)
print('Available models: ')
Expand Down
32 changes: 26 additions & 6 deletions app/transcribe/gpt_responder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import conversation
import constants
from tsutils import app_logging as al
from tsutils import duration
from tsutils import duration, utilities


root_logger = al.get_logger()
Expand Down Expand Up @@ -47,8 +47,14 @@ def summarize(self) -> str:
"""
root_logger.info(GPTResponder.summarize.__name__)

if self.config['OpenAI']['api_key'] in ('', 'API_KEY'):
# Cannot summarize without connection to LLM
chat_inference_provider = self.config['General']['chat_inference_provider']
if chat_inference_provider == 'openai':
api_key = self.config['OpenAI']['api_key']
base_url = self.config['OpenAI']['base_url']
elif chat_inference_provider == 'together':
api_key = self.config['Together']['api_key']
base_url = self.config['Together']['base_url']
if not utilities.is_api_key_valid(api_key=api_key, base_url=base_url):
return None

with duration.Duration(name='OpenAI Summarize', screen=False):
Expand Down Expand Up @@ -81,7 +87,14 @@ def generate_response_from_transcript_no_check(self) -> str:
"""
try:
root_logger.info(GPTResponder.generate_response_from_transcript_no_check.__name__)
if self.config['OpenAI']['api_key'] in ('', 'API_KEY'):
chat_inference_provider = self.config['General']['chat_inference_provider']
if chat_inference_provider == 'openai':
api_key = self.config['OpenAI']['api_key']
base_url = self.config['OpenAI']['base_url']
elif chat_inference_provider == 'together':
api_key = self.config['Together']['api_key']
base_url = self.config['Together']['base_url']
if not utilities.is_api_key_valid(api_key=api_key, base_url=base_url):
return None

with duration.Duration(name='OpenAI Chat Completion', screen=False):
Expand Down Expand Up @@ -172,7 +185,14 @@ def generate_response_for_selected_text(self, text: str):
"""
try:
root_logger.info(GPTResponder.generate_response_for_selected_text.__name__)
if self.config['OpenAI']['api_key'] in ('', 'API_KEY'):
chat_inference_provider = self.config['General']['chat_inference_provider']
if chat_inference_provider == 'openai':
api_key = self.config['OpenAI']['api_key']
base_url = self.config['OpenAI']['base_url']
elif chat_inference_provider == 'together':
api_key = self.config['Together']['api_key']
base_url = self.config['Together']['base_url']
if not utilities.is_api_key_valid(api_key=api_key, base_url=base_url):
return None

with duration.Duration(name='OpenAI Chat Completion Selected', screen=False):
Expand Down Expand Up @@ -305,7 +325,7 @@ def __init__(self,
self.llm_client = openai.OpenAI(api_key=api_key,
base_url=base_url)
self.model = self.config['Together']['ai_model']
print(f'[INFO] Using Together for inference. Model: {self.model}')
print(f'[INFO] Using Together AI for inference. Model: {self.model}')
super().__init__(config=self.config,
convo=convo,
save_to_file=save_to_file,
Expand Down
4 changes: 4 additions & 0 deletions tsutils/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,10 @@ def get_available_models(client: openai.OpenAI) -> list:
def is_api_key_valid(api_key: str, base_url) -> bool:
"""Check if it is valid openai compatible openai key for the provider
"""
if base_url == 'https://api.together.xyz':
# Together does not support the call client.models.list()
return True

openai.api_key = api_key

client = openai.OpenAI(api_key=api_key, base_url=base_url)
Expand Down

0 comments on commit 8128e17

Please sign in to comment.