Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ui,api): endpoint for bulk image uploads #7159

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 86 additions & 2 deletions invokeai/app/api/routers/images.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import io
import traceback
from typing import Optional
from typing import List, Optional

from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile
from fastapi.responses import FileResponse
Expand All @@ -15,7 +15,7 @@
ImageRecordChanges,
ResourceOrigin,
)
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
from invokeai.app.services.images.images_common import ImageBulkUploadData, ImageDTO, ImageUrlsDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection

Expand All @@ -26,6 +26,90 @@
IMAGE_MAX_AGE = 31536000


class BulkUploadImageResponse(BaseModel):
sent: int
uploading: int


@images_router.post(
"/bulk-upload",
operation_id="bulk_upload",
responses={
201: {"description": "The images are being prepared for upload"},
415: {"description": "Images upload failed"},
},
status_code=201,
response_model=BulkUploadImageResponse,
)
async def bulk_upload(
bulk_upload_id: str,
files: list[UploadFile],
background_tasks: BackgroundTasks,
request: Request,
response: Response,
board_id: Optional[str] = Query(default=None, description="The board to add this images to, if any"),
) -> BulkUploadImageResponse:
"""Uploads multiple images"""
upload_data_list: List[ImageBulkUploadData] = []

# loop to handle multiple files
for file in files:
if not file.content_type or not file.content_type.startswith("image"):
ApiDependencies.invoker.services.logger.error("Not an image")
continue

_metadata = None
_workflow = None
_graph = None

contents = await file.read()
try:
pil_image = Image.open(io.BytesIO(contents))
except Exception:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
continue

# TODO: retain non-invokeai metadata on upload?
# attempt to parse metadata from image
metadata_raw = pil_image.info.get("invokeai_metadata", None)
if isinstance(metadata_raw, str):
_metadata = metadata_raw
else:
ApiDependencies.invoker.services.logger.debug("Failed to parse metadata for uploaded image")
pass

# attempt to parse workflow from image
workflow_raw = pil_image.info.get("invokeai_workflow", None)
if isinstance(workflow_raw, str):
_workflow = workflow_raw
else:
ApiDependencies.invoker.services.logger.debug("Failed to parse workflow for uploaded image")
pass

# attempt to extract graph from image
graph_raw = pil_image.info.get("invokeai_graph", None)
if isinstance(graph_raw, str):
_graph = graph_raw
else:
ApiDependencies.invoker.services.logger.debug("Failed to parse graph for uploaded image")
pass

# construct an ImageUploadData object for each file
upload_data = ImageBulkUploadData(
image=pil_image,
board_id=board_id,
metadata=_metadata,
workflow=_workflow,
graph=_graph,
)
upload_data_list.append(upload_data)

# Schedule image processing as a background task
background_tasks.add_task(ApiDependencies.invoker.services.images.create_many, bulk_upload_id, upload_data_list)

return BulkUploadImageResponse(sent=len(files), uploading=len(upload_data_list))


@images_router.post(
"/upload",
operation_id="upload_image",
Expand Down
28 changes: 28 additions & 0 deletions invokeai/app/api/sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
BulkDownloadErrorEvent,
BulkDownloadEventBase,
BulkDownloadStartedEvent,
BulkUploadCompletedEvent,
BulkUploadErrorEvent,
BulkUploadEventBase,
BulkUploadProgressEvent,
BulkUploadStartedEvent,
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
Expand Down Expand Up @@ -53,6 +58,13 @@ class BulkDownloadSubscriptionEvent(BaseModel):
bulk_download_id: str


class BulkUploadSubscriptionEvent(BaseModel):
"""Event data for subscribing to the socket.io bulk uploads room.
This is a pydantic model to ensure the data is in the correct format."""

bulk_upload_id: str


QUEUE_EVENTS = {
InvocationStartedEvent,
InvocationProgressEvent,
Expand Down Expand Up @@ -80,6 +92,7 @@ class BulkDownloadSubscriptionEvent(BaseModel):
}

BULK_DOWNLOAD_EVENTS = {BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent}
BULK_UPLOAD_EVENTS = {BulkUploadStartedEvent, BulkUploadCompletedEvent, BulkUploadProgressEvent, BulkUploadErrorEvent}


class SocketIO:
Expand All @@ -89,6 +102,9 @@ class SocketIO:
_sub_bulk_download = "subscribe_bulk_download"
_unsub_bulk_download = "unsubscribe_bulk_download"

_sub_bulk_upload = "subscribe_bulk_upload"
_unsub_bulk_upload = "unsubscribe_bulk_upload"

def __init__(self, app: FastAPI):
self._sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
self._app = ASGIApp(socketio_server=self._sio, socketio_path="/ws/socket.io")
Expand All @@ -98,10 +114,13 @@ def __init__(self, app: FastAPI):
self._sio.on(self._unsub_queue, handler=self._handle_unsub_queue)
self._sio.on(self._sub_bulk_download, handler=self._handle_sub_bulk_download)
self._sio.on(self._unsub_bulk_download, handler=self._handle_unsub_bulk_download)
self._sio.on(self._sub_bulk_upload, handler=self._handle_sub_bulk_upload)
self._sio.on(self._unsub_bulk_upload, handler=self._handle_unsub_bulk_upload)

register_events(QUEUE_EVENTS, self._handle_queue_event)
register_events(MODEL_EVENTS, self._handle_model_event)
register_events(BULK_DOWNLOAD_EVENTS, self._handle_bulk_image_download_event)
register_events(BULK_UPLOAD_EVENTS, self._handle_bulk_image_upload_event)

async def _handle_sub_queue(self, sid: str, data: Any) -> None:
await self._sio.enter_room(sid, QueueSubscriptionEvent(**data).queue_id)
Expand All @@ -115,6 +134,12 @@ async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None:
async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)

async def _handle_sub_bulk_upload(self, sid: str, data: Any) -> None:
await self._sio.enter_room(sid, BulkUploadSubscriptionEvent(**data).bulk_upload_id)

async def _handle_unsub_bulk_upload(self, sid: str, data: Any) -> None:
await self._sio.leave_room(sid, BulkUploadSubscriptionEvent(**data).bulk_upload_id)

async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]):
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].queue_id)

Expand All @@ -123,3 +148,6 @@ async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase | Downloa

async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None:
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].bulk_download_id)

async def _handle_bulk_image_upload_event(self, event: FastAPIEvent[BulkUploadEventBase]) -> None:
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].bulk_upload_id)
29 changes: 28 additions & 1 deletion invokeai/app/services/events/events_base.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)


from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Literal, Optional

from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent,
BulkDownloadCompleteEvent,
BulkDownloadErrorEvent,
BulkDownloadStartedEvent,
BulkUploadCompletedEvent,
BulkUploadErrorEvent,
BulkUploadProgressEvent,
BulkUploadStartedEvent,
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
Expand All @@ -30,6 +34,7 @@
QueueClearedEvent,
QueueItemStatusChangedEvent,
)
from invokeai.app.services.images.images_common import ImageDTO

if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
Expand All @@ -44,6 +49,8 @@
)
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType

UploadStatusType = Literal["started", "processing", "done", "error"]


class EventServiceBase:
"""Basic event bus, to have an empty stand-in when not needed"""
Expand Down Expand Up @@ -197,3 +204,23 @@ def emit_bulk_download_error(
)

# endregion

# region Bulk image upload

def emit_bulk_upload_started(self, bulk_upload_id: str, total: int) -> None:
"""Emitted when a bulk image upload is started"""
self.dispatch(BulkUploadStartedEvent.build(bulk_upload_id, total))

def emit_bulk_upload_progress(self, bulk_upload_id: str, completed: int, total: int) -> None:
"""Emitted when a bulk image upload is started"""
self.dispatch(BulkUploadProgressEvent.build(bulk_upload_id, completed, total))

def emit_bulk_upload_complete(self, bulk_upload_id: str, total: int, image_DTO: ImageDTO) -> None:
"""Emitted when a bulk image upload is complete"""
self.dispatch(BulkUploadCompletedEvent.build(bulk_upload_id, total=total, image_DTO=image_DTO))

def emit_bulk_upload_error(self, bulk_upload_id: str, error: str) -> None:
"""Emitted when a bulk image upload has an error"""
self.dispatch(BulkUploadErrorEvent.build(bulk_upload_id, error))

# endregion
78 changes: 78 additions & 0 deletions invokeai/app/services/events/events_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from fastapi_events.registry.payload_schema import registry as payload_schema
from pydantic import BaseModel, ConfigDict, Field

from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
Expand Down Expand Up @@ -624,3 +625,80 @@ def build(
bulk_download_item_name=bulk_download_item_name,
error=error,
)


class BulkUploadEventBase(EventBase):
"""Base class for events associated with a bulk image upload"""

bulk_upload_id: str = Field(description="The ID of the bulk image download")


@payload_schema.register
class BulkUploadStartedEvent(BulkUploadEventBase):
"""Event model for bulk_upload_started"""

__event_name__ = "bulk_upload_started"

total: int = Field(description="The total numberof images")

@classmethod
def build(
cls,
bulk_upload_id: str,
total: int,
) -> "BulkUploadStartedEvent":
return cls(bulk_upload_id=bulk_upload_id, total=total)


@payload_schema.register
class BulkUploadCompletedEvent(BulkUploadEventBase):
"""Event model for bulk_upload_completed"""

__event_name__ = "bulk_upload_completed"

total: int = Field(description="The total numberof images")
image_DTO: ImageDTO = Field(description="An image from the upload so client can refetch correctly")

@classmethod
def build(cls, bulk_upload_id: str, total: int, image_DTO: ImageDTO) -> "BulkUploadCompletedEvent":
return cls(bulk_upload_id=bulk_upload_id, total=total, image_DTO=image_DTO)


@payload_schema.register
class BulkUploadProgressEvent(BulkUploadEventBase):
"""Event model for bulk_upload_progress"""

__event_name__ = "bulk_upload_progress"

completed: int = Field(description="The completed number of images")
total: int = Field(description="The total number of images")

@classmethod
def build(
cls,
bulk_upload_id: str,
completed: int,
total: int,
) -> "BulkUploadProgressEvent":
return cls(
bulk_upload_id=bulk_upload_id,
completed=completed,
total=total,
)


@payload_schema.register
class BulkUploadErrorEvent(BulkUploadEventBase):
"""Event model for bulk_upload_error"""

__event_name__ = "bulk_upload_error"

error: str = Field(description="The error message")

@classmethod
def build(
cls,
bulk_upload_id: str,
error: str,
) -> "BulkUploadErrorEvent":
return cls(bulk_upload_id=bulk_upload_id, error=error)
1 change: 1 addition & 0 deletions invokeai/app/services/image_files/image_files_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def save(
workflow: Optional[str] = None,
graph: Optional[str] = None,
thumbnail_size: int = 256,
project_id: Optional[str] = None,
) -> None:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
pass
Expand Down
1 change: 1 addition & 0 deletions invokeai/app/services/image_files/image_files_disk.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def save(
workflow: Optional[str] = None,
graph: Optional[str] = None,
thumbnail_size: int = 256,
project_id: Optional[str] = None,
) -> None:
try:
self.__validate_storage_folders()
Expand Down
2 changes: 2 additions & 0 deletions invokeai/app/services/image_records/image_records_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ def save(
session_id: Optional[str] = None,
node_id: Optional[str] = None,
metadata: Optional[str] = None,
user_id: Optional[str] = None,
project_id: Optional[str] = None,
) -> datetime:
"""Saves an image record."""
pass
Expand Down
2 changes: 2 additions & 0 deletions invokeai/app/services/image_records/image_records_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,8 @@ def save(
session_id: Optional[str] = None,
node_id: Optional[str] = None,
metadata: Optional[str] = None,
user_id: Optional[str] = None,
project_id: Optional[str] = None,
) -> datetime:
try:
self._lock.acquire()
Expand Down
9 changes: 7 additions & 2 deletions invokeai/app/services/images/images_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Callable, Optional
from typing import Callable, List, Optional

from PIL.Image import Image as PILImageType

Expand All @@ -10,7 +10,7 @@
ImageRecordChanges,
ResourceOrigin,
)
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.images.images_common import ImageBulkUploadData, ImageDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection

Expand Down Expand Up @@ -58,6 +58,11 @@ def create(
"""Creates an image, storing the file and its metadata."""
pass

@abstractmethod
def create_many(self, bulk_upload_id: str, upload_data_list: List[ImageBulkUploadData]):
"""Creates an images array DTO out of an array of images, storing the images and their metadata"""
pass

@abstractmethod
def update(
self,
Expand Down
Loading