Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/inpainting pipeline #250

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions cmd/examples/inpainting/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
// cmd/examples/inpainting/main.go

package main

import (
	"bytes"
	"fmt"
	"io"
	"log/slog"
	"mime/multipart"
	"net/http"
	"os"
	"path"
	"strconv"
	"time"
)

func main() {
if len(os.Args) < 4 {
slog.Error("Usage: main <runs> <prompt> <image_path> <mask_path>")
return
}

runs, err := strconv.Atoi(os.Args[1])
if err != nil {
slog.Error("Invalid runs arg", slog.String("error", err.Error()))
return
}

prompt := os.Args[2]
imagePath := os.Args[3]
maskPath := os.Args[4]

// Create output directory if it doesn't exist
outputDir := "output"
if err := os.MkdirAll(outputDir, 0755); err != nil {
slog.Error("Error creating output directory", slog.String("error", err.Error()))
return
}

// Prepare request URL
url := "http://localhost:8600/inpainting"

for i := 0; i < runs; i++ {
slog.Info("Running inpainting", slog.Int("run", i+1))

// Create multipart form data
var b bytes.Buffer
w := multipart.NewWriter(&b)

// Add prompt
if err := w.WriteField("prompt", prompt); err != nil {
slog.Error("Error writing prompt field", slog.String("error", err.Error()))
return
}

// Add image file
imageFile, err := os.Open(imagePath)
if err != nil {
slog.Error("Error opening image file", slog.String("error", err.Error()))
return
}
defer imageFile.Close()

fw, err := w.CreateFormFile("image", imagePath)
if err != nil {
slog.Error("Error creating form file", slog.String("error", err.Error()))
return
}
if _, err = io.Copy(fw, imageFile); err != nil {
slog.Error("Error copying image file", slog.String("error", err.Error()))
return
}

// Add mask file
maskFile, err := os.Open(maskPath)
if err != nil {
slog.Error("Error opening mask file", slog.String("error", err.Error()))
return
}
defer maskFile.Close()

fw, err = w.CreateFormFile("mask_image", maskPath)
if err != nil {
slog.Error("Error creating form file", slog.String("error", err.Error()))
return
}
if _, err = io.Copy(fw, maskFile); err != nil {
slog.Error("Error copying mask file", slog.String("error", err.Error()))
return
}

// Close the writer
w.Close()

// Create request
req, err := http.NewRequest("POST", url, &b)
if err != nil {
slog.Error("Error creating request", slog.String("error", err.Error()))
return
}
req.Header.Set("Content-Type", w.FormDataContentType())

// Send request
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
slog.Error("Error sending request", slog.String("error", err.Error()))
return
}
defer resp.Body.Close()

// Check response status
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
slog.Error("Error response from server",
slog.Int("status", resp.StatusCode),
slog.String("body", string(body)))
return
}

// Save response to file
outputPath := path.Join(outputDir, fmt.Sprintf("output_%d.json", i))
out, err := os.Create(outputPath)
if err != nil {
slog.Error("Error creating output file", slog.String("error", err.Error()))
return
}
defer out.Close()

_, err = io.Copy(out, resp.Body)
if err != nil {
slog.Error("Error saving response", slog.String("error", err.Error()))
return
}

slog.Info("Output written", slog.String("path", outputPath))
}
}
8 changes: 8 additions & 0 deletions runner/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ def load_pipeline(pipeline: str, model_id: str) -> any:
from app.pipelines.live_video_to_video import LiveVideoToVideoPipeline

return LiveVideoToVideoPipeline(model_id)
case "inpainting":
from app.pipelines.inpainting import InpaintingPipeline

return InpaintingPipeline(model_id)
case _:
raise EnvironmentError(
f"{pipeline} is not a valid pipeline for model {model_id}"
Expand Down Expand Up @@ -112,6 +116,10 @@ def load_route(pipeline: str) -> any:
from app.routes import live_video_to_video

return live_video_to_video.router
case "inpainting":
from app.routes import inpainting

return inpainting.router
case _:
raise EnvironmentError(f"{pipeline} is not a valid pipeline")

Expand Down
125 changes: 125 additions & 0 deletions runner/app/pipelines/inpainting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# runner/app/pipelines/inpainting.py

import logging
import os
from typing import List, Optional, Tuple

import PIL
import torch
from app.pipelines.base import Pipeline
from app.pipelines.utils import (
LoraLoader,
SafetyChecker,
get_model_dir,
get_torch_device,
split_prompt,
)
from app.utils.errors import InferenceError
from diffusers import AutoPipelineForInpainting, EulerAncestralDiscreteScheduler
from huggingface_hub import file_download
from PIL import Image, ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

logger = logging.getLogger(__name__)

class InpaintingPipeline(Pipeline):
    """Image inpainting pipeline backed by diffusers' AutoPipelineForInpainting.

    Regenerates the masked region of an input image according to a text
    prompt, with optional LoRA adapters and an NSFW safety check.
    """

    def __init__(self, model_id: str):
        """Load the inpainting model and supporting components.

        Args:
            model_id: HuggingFace model repository id to load from the
                local model cache directory.
        """
        self.model_id = model_id
        kwargs = {"cache_dir": get_model_dir()}

        torch_device = get_torch_device()
        folder_name = file_download.repo_folder_name(
            repo_id=model_id, repo_type="model"
        )
        folder_path = os.path.join(get_model_dir(), folder_name)

        # Load the fp16 variant if any fp16 weight file exists in the cache
        # and we are not running on CPU.
        has_fp16_variant = any(
            ".fp16.safetensors" in fname
            for _, _, files in os.walk(folder_path)
            for fname in files
        )
        if torch_device != "cpu" and has_fp16_variant:
            logger.info("InpaintingPipeline loading fp16 variant for %s", model_id)
            kwargs["torch_dtype"] = torch.float16
            kwargs["variant"] = "fp16"

        self.ldm = AutoPipelineForInpainting.from_pretrained(
            model_id,
            safety_checker=None,  # We'll use our own safety checker
            **kwargs
        ).to(torch_device)

        # Enable memory efficient attention if available
        if hasattr(self.ldm, "enable_xformers_memory_efficient_attention"):
            logger.info("Enabling xformers memory efficient attention")
            self.ldm.enable_xformers_memory_efficient_attention()

        # Initialize safety checker on specified device
        safety_checker_device = os.getenv("SAFETY_CHECKER_DEVICE", "cuda").lower()
        self._safety_checker = SafetyChecker(device=safety_checker_device)

        # Initialize LoRA support
        self._lora_loader = LoraLoader(self.ldm)

    def __call__(
        self, prompt: str, image: PIL.Image, mask_image: PIL.Image, **kwargs
    ) -> Tuple[List[PIL.Image], List[Optional[bool]]]:
        """Run inpainting on ``image`` over the region given by ``mask_image``.

        Args:
            prompt: Text prompt; "|"-separated sub-prompts (up to 3) are
                mapped to prompt/prompt_2/prompt_3 via ``split_prompt``.
            image: Source image to inpaint.
            mask_image: Mask selecting the region to regenerate.
            **kwargs: Optional pipeline settings — ``seed`` (int or list of
                ints), ``safety_check`` (bool, default True), ``loras``
                (JSON string), ``negative_prompt``, ``num_inference_steps``,
                plus anything the underlying diffusers pipeline accepts.

        Returns:
            Tuple of (generated images, per-image NSFW flags). Flags are
            ``None`` for every image when the safety check is disabled.

        Raises:
            InferenceError: If the underlying diffusion pipeline fails.
        """
        seed = kwargs.pop("seed", None)
        safety_check = kwargs.pop("safety_check", True)
        loras_json = kwargs.pop("loras", "")

        # Handle seed generation: a single int seeds one generator, a list
        # seeds one generator per requested image.
        if seed is not None:
            if isinstance(seed, int):
                kwargs["generator"] = torch.Generator(get_torch_device()).manual_seed(seed)
            elif isinstance(seed, list):
                kwargs["generator"] = [
                    torch.Generator(get_torch_device()).manual_seed(s) for s in seed
                ]

        # Handle LoRA loading
        if not loras_json:
            self._lora_loader.disable_loras()
        else:
            self._lora_loader.load_loras(loras_json)

        # Clean up inference steps if invalid
        if "num_inference_steps" in kwargs and (
            kwargs["num_inference_steps"] is None or kwargs["num_inference_steps"] < 1
        ):
            del kwargs["num_inference_steps"]

        # Split prompts if multiple are provided. The split result is merged
        # into kwargs, so the pipeline call below must NOT also receive an
        # explicit `prompt=` argument.
        prompts = split_prompt(prompt, max_splits=3)
        kwargs.update(prompts)
        # Safety net in case split_prompt yields no "prompt" key.
        kwargs.setdefault("prompt", prompt)
        neg_prompts = split_prompt(
            kwargs.pop("negative_prompt", ""),
            key_prefix="negative_prompt",
            max_splits=3,
        )
        kwargs.update(neg_prompts)

        try:
            # BUG FIX: the original passed `prompt=prompt` alongside **kwargs,
            # but kwargs already contains a "prompt" key from split_prompt,
            # which raises TypeError ("got multiple values for keyword
            # argument 'prompt'") on every call. The prompt now travels only
            # through kwargs.
            outputs = self.ldm(
                image=image,
                mask_image=mask_image,
                **kwargs
            )
        except Exception as e:
            if isinstance(e, torch.cuda.OutOfMemoryError):
                torch.cuda.empty_cache()
            raise InferenceError(original_exception=e)

        if safety_check:
            _, has_nsfw_concept = self._safety_checker.check_nsfw_images(outputs.images)
        else:
            has_nsfw_concept = [None] * len(outputs.images)

        return outputs.images, has_nsfw_concept

    def __str__(self) -> str:
        return f"InpaintingPipeline model_id={self.model_id}"
Loading