Skip to content

Commit

Permalink
Merge branch 'main' into vision-support
Browse files Browse the repository at this point in the history
  • Loading branch information
gilcu3 committed Nov 18, 2023
2 parents 1ec7ae1 + b8a82d8 commit 28ec9f9
Show file tree
Hide file tree
Showing 8 changed files with 348 additions and 49 deletions.
10 changes: 9 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ ALLOWED_TELEGRAM_USER_IDS=USER_ID_1,USER_ID_2
# VISION_TOKEN_PRICE=0.01
# ENABLE_QUOTING=true
# ENABLE_IMAGE_GENERATION=true
# ENABLE_TTS_GENERATION=true
# ENABLE_TRANSCRIPTION=true
# ENABLE_VISION=true
# PROXY=http://localhost:8080
# OPENAI_MODEL=gpt-3.5-turbo
# OPENAI_BASE_URL=https://example.com/v1/
# ASSISTANT_PROMPT="You are a helpful assistant."
# SHOW_USAGE=false
# STREAM=true
Expand All @@ -38,9 +40,15 @@ ALLOWED_TELEGRAM_USER_IDS=USER_ID_1,USER_ID_2
# TEMPERATURE=1.0
# PRESENCE_PENALTY=0.0
# FREQUENCY_PENALTY=0.0
# IMAGE_SIZE=512x512
# IMAGE_MODEL=dall-e-3
# IMAGE_QUALITY=hd
# IMAGE_STYLE=natural
# IMAGE_SIZE=1024x1024
# IMAGE_FORMAT=document
# VISION_DETAIL="low"
# GROUP_TRIGGER_KEYWORD=""
# IGNORE_GROUP_TRANSCRIPTIONS=true
# IGNORE_GROUP_VISION=true
# TTS_MODEL="tts-1"
# TTS_VOICE="alloy"
# BOT_LANGUAGE=en
65 changes: 38 additions & 27 deletions README.md

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions bot/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ def main():
'max_tokens': int(os.environ.get('MAX_TOKENS', max_tokens_default)),
'n_choices': int(os.environ.get('N_CHOICES', 1)),
'temperature': float(os.environ.get('TEMPERATURE', 1.0)),
'image_model': os.environ.get('IMAGE_MODEL', 'dall-e-2'),
'image_quality': os.environ.get('IMAGE_QUALITY', 'standard'),
'image_style': os.environ.get('IMAGE_STYLE', 'vivid'),
'image_size': os.environ.get('IMAGE_SIZE', '512x512'),
'model': model,
'enable_functions': os.environ.get('ENABLE_FUNCTIONS', str(functions_available)).lower() == 'true',
Expand All @@ -53,6 +56,8 @@ def main():
'vision_prompt': os.environ.get('VISION_PROMPT', 'What is in this image'),
'vision_detail': os.environ.get('VISION_DETAIL', 'low'),
'vision_max_tokens': int(os.environ.get('VISION_MAX_TOKENS', '300')),
'tts_model': os.environ.get('TTS_MODEL', 'tts-1'),
'tts_voice': os.environ.get('TTS_VOICE', 'alloy'),
}

if openai_config['enable_functions'] and not functions_available:
Expand All @@ -74,6 +79,7 @@ def main():
'enable_image_generation': os.environ.get('ENABLE_IMAGE_GENERATION', 'true').lower() == 'true',
'enable_transcription': os.environ.get('ENABLE_TRANSCRIPTION', 'true').lower() == 'true',
'enable_vision': os.environ.get('ENABLE_VISION', 'true').lower() == 'true',
'enable_tts_generation': os.environ.get('ENABLE_TTS_GENERATION', 'true').lower() == 'true',
'budget_period': os.environ.get('BUDGET_PERIOD', 'monthly').lower(),
'user_budgets': os.environ.get('USER_BUDGETS', os.environ.get('MONTHLY_USER_BUDGETS', '*')),
'guest_budget': float(os.environ.get('GUEST_BUDGET', os.environ.get('MONTHLY_GUEST_BUDGET', '100.0'))),
Expand All @@ -87,6 +93,9 @@ def main():
'token_price': float(os.environ.get('TOKEN_PRICE', 0.002)),
'image_prices': [float(i) for i in os.environ.get('IMAGE_PRICES', "0.016,0.018,0.02").split(",")],
'vision_token_price': float(os.environ.get('VISION_TOKEN_PRICE', '0.01')),
'image_receive_mode': os.environ.get('IMAGE_FORMAT', "photo"),
'tts_model': os.environ.get('TTS_MODEL', 'tts-1'),
'tts_prices': [float(i) for i in os.environ.get('TTS_PRICES', "0.015,0.030").split(",")],
'transcription_price': float(os.environ.get('TRANSCRIPTION_PRICE', 0.006)),
'bot_language': os.environ.get('BOT_LANGUAGE', 'en'),
}
Expand Down
51 changes: 42 additions & 9 deletions bot/openai_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import requests
import json
import httpx
import io
from datetime import date
from calendar import monthrange
from PIL import Image
Expand All @@ -21,11 +22,12 @@

# Models can be found here: https://platform.openai.com/docs/models/overview
GPT_3_MODELS = ("gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613")
GPT_3_16K_MODELS = ("gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613")
GPT_3_16K_MODELS = ("gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-1106")
GPT_4_MODELS = ("gpt-4", "gpt-4-0314", "gpt-4-0613")
GPT_4_32K_MODELS = ("gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613")
GPT_4_VISION_MODELS = ("gpt-4-vision-preview",)
GPT_ALL_MODELS = GPT_3_MODELS + GPT_3_16K_MODELS + GPT_4_MODELS + GPT_4_32K_MODELS + GPT_4_VISION_MODELS
GPT_4_128K_MODELS = ("gpt-4-1106-preview",)
GPT_ALL_MODELS = GPT_3_MODELS + GPT_3_16K_MODELS + GPT_4_MODELS + GPT_4_32K_MODELS + GPT_4_VISION_MODELS + GPT_4_128K_MODELS


def default_max_tokens(model: str) -> int:
Expand All @@ -39,12 +41,16 @@ def default_max_tokens(model: str) -> int:
return base
elif model in GPT_4_MODELS:
return base * 2
elif model in GPT_3_16K_MODELS:
elif model in GPT_3_16K_MODELS:
if model == "gpt-3.5-turbo-1106":
return 4096
return base * 4
elif model in GPT_4_32K_MODELS:
return base * 8
elif model in GPT_4_VISION_MODELS:
return 4096
elif model in GPT_4_128K_MODELS:
return 4096


def are_functions_available(model: str) -> bool:
Expand All @@ -55,7 +61,7 @@ def are_functions_available(model: str) -> bool:
if model in ("gpt-3.5-turbo-0301", "gpt-4-0314", "gpt-4-32k-0314"):
return False
# Stable models will be updated to support functions on June 27, 2023
if model in ("gpt-3.5-turbo", "gpt-4", "gpt-4-32k"):
if model in ("gpt-3.5-turbo", "gpt-3.5-turbo-1106", "gpt-4", "gpt-4-32k","gpt-4-1106-preview"):
return datetime.date.today() > datetime.date(2023, 6, 27)
if model == 'gpt-4-vision-preview':
return False
Expand Down Expand Up @@ -326,6 +332,9 @@ async def generate_image(self, prompt: str) -> tuple[str, str]:
response = await self.client.images.generate(
prompt=prompt,
n=1,
model=self.config['image_model'],
quality=self.config['image_quality'],
style=self.config['image_style'],
size=self.config['image_size']
)

Expand All @@ -340,6 +349,28 @@ async def generate_image(self, prompt: str) -> tuple[str, str]:
except Exception as e:
raise Exception(f"⚠️ _{localized_text('error', bot_language)}._ ⚠️\n{str(e)}") from e

async def generate_speech(self, text: str) -> tuple[io.BytesIO, int]:
    """
    Generates speech audio from the given text using the configured TTS model.
    :param text: The text to send to the model
    :return: An in-memory stream holding the Opus-encoded audio (positioned at its start)
             and the length of the input text
    :raises Exception: wrapping any error raised while requesting or reading the audio
    """
    bot_language = self.config['bot_language']
    try:
        response = await self.client.audio.speech.create(
            model=self.config['tts_model'],
            voice=self.config['tts_voice'],
            input=text,
            response_format='opus'
        )

        # Building the BytesIO directly from the payload leaves the stream at
        # offset 0, so no separate write()/seek(0) round trip is required.
        return io.BytesIO(response.read()), len(text)
    except Exception as e:
        raise Exception(f"⚠️ _{localized_text('error', bot_language)}._ ⚠️\n{str(e)}") from e

async def transcribe(self, filename):
"""
Transcribes the audio file using the Whisper model.
Expand Down Expand Up @@ -372,7 +403,7 @@ async def interpret_image(self, chat_id, fileobj, prompt=None):
message = {'role':'user', 'content':[{'type':'text', 'text':prompt}, {'type':'image_url', \
'image_url': {'url':f'data:image/jpeg;base64,{image}', 'detail':self.config['vision_detail'] } }]}
common_args = {
'model': self.config['model'],
'model': 'gpt-4-vision-preview', # the only one that currently makes sense here
'messages': self.conversations[chat_id] + [message],
'temperature': self.config['temperature'],
'n': 1, # several choices is not implemented yet
Expand Down Expand Up @@ -451,7 +482,7 @@ async def __summarise(self, conversation) -> str:
messages=messages,
temperature=0.4
)
return response.choices[0]['message']['content']
return response.choices[0].message.content

def __max_model_tokens(self):
base = 4096
Expand All @@ -465,6 +496,8 @@ def __max_model_tokens(self):
return base * 8
if self.config['model'] in GPT_4_VISION_MODELS:
return base * 31
if self.config['model'] in GPT_4_128K_MODELS:
return base * 31
raise NotImplementedError(
f"Max tokens for model {self.config['model']} is not implemented yet."
)
Expand All @@ -485,7 +518,7 @@ def __count_tokens(self, messages) -> int:
if model in GPT_3_MODELS + GPT_3_16K_MODELS:
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
tokens_per_name = -1 # if there's a name, the role is omitted
elif model in GPT_4_MODELS + GPT_4_32K_MODELS + GPT_4_VISION_MODELS:
elif model in GPT_4_MODELS + GPT_4_32K_MODELS + GPT_4_VISION_MODELS + GPT_4_128K_MODELS:
tokens_per_message = 3
tokens_per_name = 1
else:
Expand All @@ -507,9 +540,9 @@ def __count_tokens_vision(self, fileobj) -> int:
:return: the number of tokens required
"""
image = Image.open(fileobj)
model = self.config['model']
model = 'gpt-4-vision-preview' # fixed for now
if model not in GPT_4_VISION_MODELS:
raise NotImplementedError(f"""num_tokens_from_messages() is not implemented for model {model}.""")
raise NotImplementedError(f"""count_tokens_vision() is not implemented for model {model}.""")

w, h = image.size
if w > h: w, h = h, w
Expand Down
77 changes: 73 additions & 4 deletions bot/telegram_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def __init__(self, config: dict, openai: OpenAIHelper):
if self.config.get('enable_image_generation', False):
self.commands.append(BotCommand(command='image', description=localized_text('image_description', bot_language)))

if self.config.get('enable_tts_generation', False):
self.commands.append(BotCommand(command='tts', description=localized_text('tts_description', bot_language)))

self.group_commands = [BotCommand(
command='chat', description=localized_text('chat_description', bot_language)
)] + self.commands
Expand Down Expand Up @@ -97,6 +100,7 @@ async def stats(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
(transcribe_minutes_today, transcribe_seconds_today, transcribe_minutes_month,
transcribe_seconds_month) = self.usage[user_id].get_current_transcription_duration()
vision_today, vision_month = self.usage[user_id].get_current_vision_tokens()
characters_today, characters_month = self.usage[user_id].get_current_tts_usage()
current_cost = self.usage[user_id].get_current_cost()

chat_id = update.effective_chat.id
Expand All @@ -119,12 +123,17 @@ async def stats(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
text_today_vision = ""
if self.config.get('enable_vision', False):
text_today_vision = f"{vision_today} {localized_text('stats_vision', bot_language)}\n"

text_today_tts = ""
if self.config.get('enable_tts_generation', False):
text_today_tts = f"{characters_today} {localized_text('stats_tts', bot_language)}\n"

text_today = (
f"*{localized_text('usage_today', bot_language)}:*\n"
f"{tokens_today} {localized_text('stats_tokens', bot_language)}\n"
f"{text_today_images}" # Include the image statistics for today if applicable
f"{text_today_vision}"
f"{text_today_tts}"
f"{transcribe_minutes_today} {localized_text('stats_transcribe', bot_language)[0]} "
f"{transcribe_seconds_today} {localized_text('stats_transcribe', bot_language)[1]}\n"
f"{localized_text('stats_total', bot_language)}{current_cost['cost_today']:.2f}\n"
Expand All @@ -138,13 +147,18 @@ async def stats(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
text_month_vision = ""
if self.config.get('enable_vision', False):
text_month_vision = f"{vision_month} {localized_text('stats_vision', bot_language)}\n"

text_month_tts = ""
if self.config.get('enable_tts_generation', False):
text_month_tts = f"{characters_month} {localized_text('stats_tts', bot_language)}\n"

# Check if image generation is enabled and, if so, generate the image statistics for the month
text_month = (
f"*{localized_text('usage_month', bot_language)}:*\n"
f"{tokens_month} {localized_text('stats_tokens', bot_language)}\n"
f"{text_month_images}" # Include the image statistics for the month if applicable
f"{text_month_vision}"
f"{text_month_tts}"
f"{transcribe_minutes_month} {localized_text('stats_transcribe', bot_language)[0]} "
f"{transcribe_seconds_month} {localized_text('stats_transcribe', bot_language)[1]}\n"
f"{localized_text('stats_total', bot_language)}{current_cost['cost_month']:.2f}"
Expand Down Expand Up @@ -241,10 +255,18 @@ async def image(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
async def _generate():
try:
image_url, image_size = await self.openai.generate_image(prompt=image_query)
await update.effective_message.reply_photo(
reply_to_message_id=get_reply_to_message_id(self.config, update),
photo=image_url
)
if self.config['image_receive_mode'] == 'photo':
await update.effective_message.reply_photo(
reply_to_message_id=get_reply_to_message_id(self.config, update),
photo=image_url
)
elif self.config['image_receive_mode'] == 'document':
await update.effective_message.reply_document(
reply_to_message_id=get_reply_to_message_id(self.config, update),
document=image_url
)
else:
raise Exception(f"env variable IMAGE_RECEIVE_MODE has invalid value {self.config['image_receive_mode']}")
# add image request to users usage tracker
user_id = update.message.from_user.id
self.usage[user_id].add_image_request(image_size, self.config['image_prices'])
Expand All @@ -263,6 +285,52 @@ async def _generate():

await wrap_with_indicator(update, context, _generate, constants.ChatAction.UPLOAD_PHOTO)

async def tts(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
    """
    Generates speech for the given input using the TTS API and replies with it as a voice message.

    Does nothing when TTS generation is disabled in the config or when
    check_allowed_and_within_budget() rejects the request. An empty query is
    answered with the localized 'tts_no_prompt' text. Failures while generating
    or sending the audio are logged and reported back with the localized
    'tts_fail' text.
    """
    if not self.config['enable_tts_generation'] \
            or not await self.check_allowed_and_within_budget(update, context):
        return

    tts_query = message_text(update.message)
    if tts_query == '':
        await update.effective_message.reply_text(
            message_thread_id=get_thread_id(update),
            text=localized_text('tts_no_prompt', self.config['bot_language'])
        )
        return

    logging.info(f'New speech generation request received from user {update.message.from_user.name} '
                 f'(id: {update.message.from_user.id})')

    async def _generate():
        try:
            speech_file, text_length = await self.openai.generate_speech(text=tts_query)

            # Close the audio buffer even if sending the voice message fails,
            # so a failed upload does not leave the stream open.
            try:
                await update.effective_message.reply_voice(
                    reply_to_message_id=get_reply_to_message_id(self.config, update),
                    voice=speech_file
                )
            finally:
                speech_file.close()

            # add TTS request to users usage tracker
            user_id = update.message.from_user.id
            self.usage[user_id].add_tts_request(text_length, self.config['tts_model'], self.config['tts_prices'])
            # add guest TTS request to guest usage tracker
            if str(user_id) not in self.config['allowed_user_ids'].split(',') and 'guests' in self.usage:
                self.usage["guests"].add_tts_request(text_length, self.config['tts_model'], self.config['tts_prices'])

        except Exception as e:
            logging.exception(e)
            await update.effective_message.reply_text(
                message_thread_id=get_thread_id(update),
                reply_to_message_id=get_reply_to_message_id(self.config, update),
                text=f"{localized_text('tts_fail', self.config['bot_language'])}: {str(e)}",
                parse_mode=constants.ParseMode.MARKDOWN
            )

    await wrap_with_indicator(update, context, _generate, constants.ChatAction.UPLOAD_VOICE)

async def transcribe(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
"""
Transcribe audio messages.
Expand Down Expand Up @@ -909,6 +977,7 @@ def run(self):
application.add_handler(CommandHandler('reset', self.reset))
application.add_handler(CommandHandler('help', self.help))
application.add_handler(CommandHandler('image', self.image))
application.add_handler(CommandHandler('tts', self.tts))
application.add_handler(CommandHandler('start', self.help))
application.add_handler(CommandHandler('stats', self.stats))
application.add_handler(CommandHandler('resend', self.resend))
Expand Down
Loading

0 comments on commit 28ec9f9

Please sign in to comment.