Add image generation to Home tab
seratch committed May 22, 2024
1 parent 11a88ce commit 75fcc34
Showing 8 changed files with 312 additions and 16 deletions.
100 changes: 98 additions & 2 deletions app/bolt_listeners.py
@@ -13,9 +13,13 @@
SYSTEM_TEXT,
TRANSLATE_MARKDOWN,
IMAGE_FILE_ACCESS_ENABLED,
OPENAI_IMAGE_GENERATION_MODEL,
)
from app.i18n import translate
from app.image_ops import append_image_content_if_exists
from app.openai_image_ops import (
append_image_content_if_exists,
generate_image,
)
from app.openai_ops import (
start_receiving_openai_response,
format_openai_message_content,
@@ -56,6 +60,10 @@
build_from_scratch_result_modal,
build_from_scratch_timeout_modal,
build_from_scratch_error_modal,
build_image_generation_input_modal,
build_image_generation_wip_modal,
build_image_generation_result_modal,
build_image_generation_text_modal,
)


@@ -197,6 +205,7 @@ def respond_to_app_mention(
openai_api_base=context["OPENAI_API_BASE"],
openai_api_version=context["OPENAI_API_VERSION"],
openai_deployment_id=context["OPENAI_DEPLOYMENT_ID"],
openai_organization_id=context["OPENAI_ORG_ID"],
function_call_module_name=context["OPENAI_FUNCTION_CALL_MODULE_NAME"],
)
consume_openai_stream_to_write_reply(
@@ -435,6 +444,7 @@ def respond_to_new_message(
openai_api_base=context["OPENAI_API_BASE"],
openai_api_version=context["OPENAI_API_VERSION"],
openai_deployment_id=context["OPENAI_DEPLOYMENT_ID"],
openai_organization_id=context["OPENAI_ORG_ID"],
function_call_module_name=context["OPENAI_FUNCTION_CALL_MODULE_NAME"],
)

@@ -668,7 +678,7 @@ def display_proofreading_result(
logger.exception(f"Failed to share a proofreading result: {e}")
client.views_update(
view_id=payload["id"],
view=build_proofreading_error_modal(payload=payload, text=text),
view=build_proofreading_error_modal(payload=payload, text=text, e=e),
)


@@ -712,6 +722,82 @@ def send_proofreading_result_in_dm(
logger.exception(f"Failed to send a DM: {e}")


#
# Image generation
#


def start_image_generation(client: WebClient, body: dict, payload: dict):
client.views_open(
trigger_id=body.get("trigger_id"),
view=build_image_generation_input_modal(payload.get("value")),
)


def ack_image_generation_modal_submission(ack: Ack):
ack(response_action="update", view=build_image_generation_wip_modal())


def display_image_generation_result(
client: WebClient,
context: BoltContext,
logger: logging.Logger,
payload: dict,
):
text = ""
try:
prompt = extract_state_value(payload, "image_generation_prompt").get("value")
size = extract_state_value(payload, "size").get("selected_option").get("value")
quality = (
extract_state_value(payload, "quality").get("selected_option").get("value")
)
style = (
extract_state_value(payload, "style").get("selected_option").get("value")
)

start_time = time.time()
image_url = generate_image(
context=context,
prompt=prompt,
size=size,
quality=quality,
style=style,
timeout_seconds=OPENAI_TIMEOUT_SECONDS,
)
spent_seconds = time.time() - start_time
logger.debug(
f"Image generated (url: {image_url} , spent time: {spent_seconds})"
)
model = context.get(
"OPENAI_IMAGE_GENERATION_MODEL", OPENAI_IMAGE_GENERATION_MODEL
)
view = build_image_generation_result_modal(
prompt=prompt,
spent_seconds=str(round(spent_seconds, 2)),
image_url=image_url,
model=model,
size=size,
quality=quality,
style=style,
)
client.views_update(view_id=payload["id"], view=view)

except (APITimeoutError, TimeoutError):
client.views_update(
view_id=payload["id"],
view=build_image_generation_text_modal(TIMEOUT_ERROR_MESSAGE),
)
except Exception as e:
logger.exception(f"Failed to share a generated image: {e}")
client.views_update(
view_id=payload["id"],
view=build_image_generation_text_modal(
f"{text}\n\n:warning: My apologies! "
f"An error occurred while generating an image: {e}"
),
)


#
# Chat from scratch
#
@@ -811,6 +897,16 @@ def attach_bot_scopes(client: WebClient, context: BoltContext, next_):
lazy=[send_proofreading_result_in_dm],
)

# Image generation
app.action("templates-image-generation")(
ack=just_ack,
lazy=[start_image_generation],
)
app.view("image-generation")(
ack=ack_image_generation_modal_submission,
lazy=[display_image_generation_result],
)

# Free format chat
app.action("templates-from-scratch")(
ack=just_ack,
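For orientation, here is a hedged sketch of the view_submission state that display_image_generation_result reads through extract_state_value. Only the names image_generation_prompt, size, quality, and style and the nested value / selected_option access come from the handler above; the block IDs and the lookup behaviour of extract_state_value are assumptions for illustration.

# Hypothetical "image-generation" modal state (sketch only, not part of this commit).
# Slack nests submissions as state.values[block_id][action_id]; the block IDs below are invented.
example_view_state = {
    "values": {
        "prompt_block": {
            "image_generation_prompt": {
                "type": "plain_text_input",
                "value": "A watercolor painting of a fox reading a book",
            }
        },
        "size_block": {
            "size": {"type": "static_select", "selected_option": {"value": "1024x1024"}}
        },
        "quality_block": {
            "quality": {"type": "static_select", "selected_option": {"value": "standard"}}
        },
        "style_block": {
            "style": {"type": "static_select", "selected_option": {"value": "vivid"}}
        },
    }
}
# With a state shaped like this, the handler's
#   extract_state_value(payload, "size").get("selected_option").get("value")
# would resolve to "1024x1024".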
8 changes: 8 additions & 0 deletions app/env.py
@@ -17,6 +17,11 @@
DEFAULT_OPENAI_MODEL = "gpt-3.5-turbo"
OPENAI_MODEL = os.environ.get("OPENAI_MODEL", DEFAULT_OPENAI_MODEL)

DEFAULT_OPENAI_IMAGE_GENERATION_MODEL = "dall-e-3"
OPENAI_IMAGE_GENERATION_MODEL = os.environ.get(
"OPENAI_IMAGE_GENERATION_MODEL", DEFAULT_OPENAI_IMAGE_GENERATION_MODEL
)

DEFAULT_OPENAI_TEMPERATURE = 1
OPENAI_TEMPERATURE = float(
os.environ.get("OPENAI_TEMPERATURE", DEFAULT_OPENAI_TEMPERATURE)
@@ -36,6 +41,9 @@
"OPENAI_DEPLOYMENT_ID", DEFAULT_OPENAI_DEPLOYMENT_ID
)

DEFAULT_OPENAI_ORG_ID = None
OPENAI_ORG_ID = os.environ.get("OPENAI_ORG_ID", DEFAULT_OPENAI_ORG_ID)

DEFAULT_OPENAI_FUNCTION_CALL_MODULE_NAME = None
OPENAI_FUNCTION_CALL_MODULE_NAME = os.environ.get(
"OPENAI_FUNCTION_CALL_MODULE_NAME", DEFAULT_OPENAI_FUNCTION_CALL_MODULE_NAME
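As a hedged illustration of how the two new settings behave (the values below are examples, not requirements): both variables are optional, and leaving them unset falls back to "dall-e-3" and to no organization, respectively.

# Sketch only: configuring the new variables before the app process starts.
import os

os.environ["OPENAI_IMAGE_GENERATION_MODEL"] = "dall-e-3"  # model passed to images.generate
os.environ["OPENAI_ORG_ID"] = "org-example123"            # forwarded as organization= to the OpenAI client

# Unset, app.env resolves them to DEFAULT_OPENAI_IMAGE_GENERATION_MODEL ("dall-e-3")
# and DEFAULT_OPENAI_ORG_ID (None), so existing deployments keep working unchanged.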
26 changes: 25 additions & 1 deletion app/image_ops.py → app/openai_image_ops.py
@@ -1,11 +1,13 @@
import logging
from typing import List, Tuple
from typing import List, Tuple, Literal

import base64
from io import BytesIO
from PIL import Image

from app.openai_ops import create_openai_client
from app.slack_ops import download_slack_image_content
from slack_bolt import BoltContext


SUPPORTED_IMAGE_FORMATS = ["jpeg", "png", "gif"]
@@ -52,3 +54,25 @@ def encode_image_and_guess_format(image_data: bytes) -> Tuple[str, str]:

base64encoded_image_data = base64.b64encode(image_data).decode("utf-8")
return base64encoded_image_data, image_format


def generate_image(
*,
context: BoltContext,
prompt: str,
size: Literal["1024x1024", "1792x1024", "1024x1792"] = "1024x1024",
quality: Literal["standard", "hd"] = "standard",
style: Literal["vivid", "natural"] = "vivid",
timeout_seconds: int,
) -> str:
client = create_openai_client(context)
response = client.images.generate(
model=context["OPENAI_IMAGE_GENERATION_MODEL"],
prompt=prompt,
size=size,
quality=quality,
style=style,
timeout=timeout_seconds,
n=1,
)
return response.data[0].url
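A minimal usage sketch for the new generate_image helper, assuming a BoltContext populated the way this app's listeners receive it; the key set and placeholder values below are illustrative only.

# Sketch only: calling generate_image with a hand-built context.
from slack_bolt import BoltContext

from app.openai_image_ops import generate_image

context = BoltContext()
context["OPENAI_API_TYPE"] = None                      # anything other than "azure" yields a plain OpenAI client
context["OPENAI_API_KEY"] = "sk-..."                   # placeholder
context["OPENAI_API_BASE"] = None                      # default OpenAI endpoint
context["OPENAI_IMAGE_GENERATION_MODEL"] = "dall-e-3"  # read inside generate_image

image_url = generate_image(
    context=context,
    prompt="A watercolor painting of a fox reading a book",
    size="1024x1024",
    quality="standard",
    style="vivid",
    timeout_seconds=60,
)
print(image_url)  # URL returned by the OpenAI Images API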
36 changes: 24 additions & 12 deletions app/openai_ops.py
@@ -106,6 +106,7 @@ def make_synchronous_openai_call(
openai_api_base: str,
openai_api_version: str,
openai_deployment_id: str,
openai_organization_id: Optional[str],
timeout_seconds: int,
) -> Completion:
if openai_api_type == "azure":
@@ -119,6 +120,7 @@
client = OpenAI(
api_key=openai_api_key,
base_url=openai_api_base,
organization=openai_organization_id,
)
return client.chat.completions.create(
model=model,
@@ -147,6 +149,7 @@ def start_receiving_openai_response(
openai_api_base: str,
openai_api_version: str,
openai_deployment_id: str,
openai_organization_id: Optional[str],
function_call_module_name: Optional[str],
) -> Stream[Completion]:
kwargs = {}
@@ -163,6 +166,7 @@
client = OpenAI(
api_key=openai_api_key,
base_url=openai_api_base,
organization=openai_organization_id,
)
return client.chat.completions.create(
model=model,
@@ -274,6 +278,7 @@ def update_message():
openai_api_base=context.get("OPENAI_API_BASE"),
openai_api_version=context.get("OPENAI_API_VERSION"),
openai_deployment_id=context.get("OPENAI_DEPLOYMENT_ID"),
openai_organization_id=context["OPENAI_ORG_ID"],
function_call_module_name=function_call_module_name,
)
consume_openai_stream_to_write_reply(
@@ -494,18 +499,7 @@ def calculate_tokens_necessary_for_function_call(context: BoltContext) -> int:
return _prompt_tokens_used_by_function_call_cache

def _calculate_prompt_tokens(functions) -> int:
if context.get("OPENAI_API_TYPE") == "azure":
client = AzureOpenAI(
api_key=context.get("OPENAI_API_KEY"),
api_version=context.get("OPENAI_API_VERSION"),
azure_endpoint=context.get("OPENAI_API_BASE"),
azure_deployment=context.get("OPENAI_DEPLOYMENT_ID"),
)
else:
client = OpenAI(
api_key=context.get("OPENAI_API_KEY"),
base_url=context.get("OPENAI_API_BASE"),
)
client = create_openai_client(context)
return client.chat.completions.create(
model=context.get("OPENAI_MODEL"),
messages=[{"role": "user", "content": "hello"}],
@@ -559,6 +553,7 @@ def generate_slack_thread_summary(
openai_api_base=context["OPENAI_API_BASE"],
openai_api_version=context["OPENAI_API_VERSION"],
openai_deployment_id=context["OPENAI_DEPLOYMENT_ID"],
openai_organization_id=context["OPENAI_ORG_ID"],
timeout_seconds=timeout_seconds,
)
spent_time = time.time() - start_time
@@ -611,6 +606,7 @@ def generate_proofreading_result(
openai_api_base=context["OPENAI_API_BASE"],
openai_api_version=context["OPENAI_API_VERSION"],
openai_deployment_id=context["OPENAI_DEPLOYMENT_ID"],
openai_organization_id=context["OPENAI_ORG_ID"],
timeout_seconds=timeout_seconds,
)
spent_time = time.time() - start_time
@@ -649,8 +645,24 @@ def generate_chatgpt_response(
openai_api_base=context["OPENAI_API_BASE"],
openai_api_version=context["OPENAI_API_VERSION"],
openai_deployment_id=context["OPENAI_DEPLOYMENT_ID"],
openai_organization_id=context["OPENAI_ORG_ID"],
timeout_seconds=timeout_seconds,
)
spent_time = time.time() - start_time
logger.debug(f"Proofreading took {spent_time} seconds")
return openai_response.model_dump()["choices"][0]["message"]["content"]


def create_openai_client(context: BoltContext) -> OpenAI | AzureOpenAI:
if context.get("OPENAI_API_TYPE") == "azure":
return AzureOpenAI(
api_key=context.get("OPENAI_API_KEY"),
api_version=context.get("OPENAI_API_VERSION"),
azure_endpoint=context.get("OPENAI_API_BASE"),
azure_deployment=context.get("OPENAI_DEPLOYMENT_ID"),
)
else:
return OpenAI(
api_key=context.get("OPENAI_API_KEY"),
base_url=context.get("OPENAI_API_BASE"),
)
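
Finally, a hedged sketch of the extracted create_openai_client helper in use; the Azure values are placeholders, and any OPENAI_API_TYPE other than "azure" yields a plain OpenAI client.

# Sketch only: the helper centralizes the Azure/OpenAI branching that
# _calculate_prompt_tokens previously inlined.
from slack_bolt import BoltContext

from app.openai_ops import create_openai_client

context = BoltContext()
context["OPENAI_API_TYPE"] = "azure"
context["OPENAI_API_KEY"] = "..."                                # placeholder
context["OPENAI_API_VERSION"] = "2024-02-01"                     # example Azure API version
context["OPENAI_API_BASE"] = "https://example.openai.azure.com"  # example resource endpoint
context["OPENAI_DEPLOYMENT_ID"] = "my-gpt-deployment"            # example deployment name

client = create_openai_client(context)
# Both generate_image and _calculate_prompt_tokens now obtain their client this way
# instead of duplicating the branching at each call site.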