Skip to content
This repository has been archived by the owner on Jul 12, 2024. It is now read-only.

Commit

Permalink
add model browser support (#36)
Browse files Browse the repository at this point in the history
Signed-off-by: AnyISalIn <[email protected]>
  • Loading branch information
AnyISalIn authored Aug 27, 2023
1 parent e2e2f71 commit ce80c75
Show file tree
Hide file tree
Showing 5 changed files with 881 additions and 92 deletions.
184 changes: 117 additions & 67 deletions extension/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import importlib

from omniinfer_client import *
from dataclass_wizard import JSONWizard, DumpMeta
from dataclasses import dataclass, field

from typing import Dict

Expand Down Expand Up @@ -52,34 +54,37 @@ def upscale(self, *args, **kwargs):
pass


class StableDiffusionModel(object):
def __init__(self,
kind,
name,
rating=0,
tags=None,
child=None,
examples=None,
user_tags=None):
self.kind = kind # checkpoint, lora
self.name = name
self.rating = rating
self.tags = tags
if self.tags is None:
self.tags = []

self.user_tags = user_tags
if self.user_tags is None:
self.user_tags = []

self.child = child
if self.child is None:
self.child = []

self.examples = examples

def append_child(self, child):
self.child.append(child)
class JSONe(JSONWizard):
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
DumpMeta(key_transform='SNAKE').bind_to(cls)


@dataclass
class StableDiffusionModelExample(JSONe):
prompts: Optional[str] = None
neg_prompt: Optional[str] = None
sampler_name: Optional[str] = None
steps: Optional[int] = None
seed: Optional[int] = None
height: Optional[int] = None
width: Optional[int] = None
preview: Optional[str] = None
cfg_scale: Optional[float] = None


@dataclass
class StableDiffusionModel(JSONe):
kind: str
name: str
rating: int = 0
tags: List[str] = None
child: Optional[List[str]] = field(default_factory=lambda: [])
examples: Optional[List[StableDiffusionModelExample]] = field(default_factory=lambda: [])
user_tags: Optional[List[str]] = field(default_factory=lambda: [])
preview_url: Optional[str] = None
search_terms: Optional[List[str]] = field(default_factory=lambda: [])
origin_url: Optional[str] = None

@property
def alias(self):
Expand All @@ -92,39 +97,35 @@ def alias(self):
if len(self.tags) > 0:
n = "[{}] ".format(",".join(self.tags))
return n + os.path.splitext(self.name)[0]

def to_json(self):
d = {}
for k, v in self.__dict__.items():
if isinstance(v, StableDiffusionModelExample):
d[k] = v.__dict__
else:
d[k] = v
return d


class StableDiffusionModelExample(object):

def __init__(self,
prompts=None,
neg_prompt=None,
sampler_name=None,
steps=None,
cfg_scale=None,
seed=None,
height=None,
width=None,
preview=None,
):
self.prompts = prompts
self.neg_prompt = neg_prompt
self.sampler_name = sampler_name
self.steps = steps
self.cfg_scale = cfg_scale
self.seed = seed
self.height = height
self.width = width
self.preview = preview

def add_user_tag(self, tag):
if tag not in self.user_tags:
self.user_tags.append(tag)



# class StableDiffusionModelExample(object):

# def __init__(self,
# prompts=None,
# neg_prompt=None,
# sampler_name=None,
# steps=None,
# cfg_scale=None,
# seed=None,
# height=None,
# width=None,
# preview=None,
# ):
# self.prompts = prompts
# self.neg_prompt = neg_prompt
# self.sampler_name = sampler_name
# self.steps = steps
# self.cfg_scale = cfg_scale
# self.seed = seed
# self.height = height
# self.width = width
# self.preview = preview


class OmniinferAPI(BaseAPI, UpscaleAPI):
Expand Down Expand Up @@ -159,6 +160,14 @@ def load_from_config(cls):
# if no key, we will set it to NONE
o._api_key = 'NONE'
o.update_client()

if config.get('models') is not None:
try:
o._models = [StableDiffusionModel.from_dict(m) for m in config['models']]
except Exception as exp:
print('[cloud-inference] failed to load models from config file {}, we will create a new one'.format(exp))
o._models = []

return o

@classmethod
Expand All @@ -180,6 +189,25 @@ def update_key_to_config(cls, key):
json.dumps(config, ensure_ascii=False, indent=2,
default=vars).encode('utf-8'))

@classmethod
def update_models_to_config(cls, models):
config = {}
if os.path.exists(OMNIINFER_CONFIG):
with open(OMNIINFER_CONFIG, 'r') as f:
try:
config = json.load(f)
except:
print(
'[cloud-inference] failed to load config file, we will create a new one'
)
pass

config['models'] = models
with open(OMNIINFER_CONFIG, 'wb+') as f:
f.write(
json.dumps(config, ensure_ascii=False, indent=2,
default=vars).encode('utf-8'))

@classmethod
def test_connection(cls, api_key: str):
client = OmniClient(api_key)
Expand Down Expand Up @@ -390,6 +418,11 @@ def get_models(type_):
for item in models:
model = StableDiffusionModel(kind=item.type.value,
name=item.sd_name)
model.search_terms = [
item.sd_name,
item.name,
str(item.civitai_version_id)
]
model.rating = item.civitai_download_count
civitai_tags = item.civitai_tags.split(",") if item.civitai_tags is not None else []

Expand All @@ -399,9 +432,12 @@ def get_models(type_):
if len(civitai_tags) > 0:
model.tags.append(civitai_tags[0])

if item.civitai_nsfw:
if item.civitai_nsfw or item.civitai_image_nsfw:
model.tags.append("nsfw")

if item.civitai_image_url:
model.preview_url = item.civitai_image_url

model.examples = []
if item.civitai_images:
for img in item.civitai_images:
Expand All @@ -426,11 +462,25 @@ def get_models(type_):
sd_models.extend(get_models(ModelType.CONTROLNET))
sd_models.extend(get_models(ModelType.VAE))
sd_models.extend(get_models(ModelType.UPSCALER))
sd_models.extend(get_models(ModelType.TEXT_INVERSION))

# build lora and checkpoint relationship

self._models = sd_models
return sd_models
merged_models = {}
origin_models = {}
for model in self._models:
origin_models[model.name] = model
for model in sd_models:
if model.name in origin_models:
# save user tags
merged_models[model.name] = model
merged_models[model.name].user_tags = origin_models[model.name].user_tags
else:
merged_models[model.name] = model

self._models = [v for k, v in merged_models.items()]
self.update_models_to_config(self._models)
return self._models


_instance = None
Expand Down Expand Up @@ -574,7 +624,7 @@ def prepare_mask(
def image_dict_from_any(image) -> Optional[Dict[str, np.ndarray]]:
if image is None:
return None

if isinstance(image, (tuple, list)):
image = {'image': image[0], 'mask': image[1]}
elif not isinstance(image, dict):
Expand All @@ -595,10 +645,10 @@ def image_dict_from_any(image) -> Optional[Dict[str, np.ndarray]]:
if image['image'] is None:
image['mask'] = None
return image

if 'mask' not in image:
image['mask'] = None

if isinstance(image['mask'], str):
if os.path.exists(image['mask']):
image['mask'] = np.array(Image.open(image['mask'])).astype('uint8')
Expand Down
Loading

0 comments on commit ce80c75

Please sign in to comment.