Skip to content
This repository has been archived by the owner on Jul 12, 2024. It is now read-only.

Commit

Permalink
Fix img2img button visible and refactor model dropdown (#11)
Browse files Browse the repository at this point in the history
* Fix img2img button visible
* Feat: refator cloud models dropdown
* Feat: support cloud model xyz plot

---------

Signed-off-by: AnyISalIn <[email protected]>
  • Loading branch information
AnyISalIn authored Jul 11, 2023
1 parent 3e85f3c commit 85938ac
Show file tree
Hide file tree
Showing 2 changed files with 315 additions and 184 deletions.
169 changes: 92 additions & 77 deletions extension/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,65 +32,75 @@ def refresh_models() -> list[str]:
pass


class CheckpointExample(object):
class StableDiffusionModel(object):

def __init__(self,
prompts=None,
neg_prompt=None,
sampler_name=None,
steps=None,
cfg_scale=None,
seed=None,
height=None,
width=None):
self.prompts = prompts
self.neg_prompt = neg_prompt
self.sampler_name = sampler_name
self.steps = steps
self.cfg_scale = cfg_scale
self.seed = seed
self.height = height
self.width = width


class Checkpoint(object):

def __init__(self,
name: str,
rating: int = None,
loras: list[str] = None,
tags: list[str] = None,
example: CheckpointExample = None):
kind,
name,
rating=0,
tags=None,
child=None,
example=None,
dependency_model_name=None):
self.kind = kind # checkpoint, lora
self.name = name
self.rating = rating
self.loras = loras
self.tags = tags
if self.tags is None:
self.tags = []

if self.loras is None:
self.loras = []
self.child = child
if self.child is None:
self.child = []

self.tags = []
if tags is not None:
self.tags = tags
self.example = example
self.dependency_model_name = dependency_model_name

def append_child(self, child):
self.child.append(child)

@property
def display_name(self):
n = ""
# format -> [<ckpt/lora>] [<tag>] <name>
kind = self.kind
if self.kind == 'checkpoint':
kind = 'ckpt'

n = "[{}] ".format(kind)

if self.tags is not None and len(self.tags) != 0:
n += "[{}] ".format(self.tags[0])
return n + self.name

def to_json(self):
d = {}
for k, v in self.__dict__.items():
if isinstance(v, CheckpointExample):
if isinstance(v, StableDiffusionModelExample):
d[k] = v.__dict__
else:
d[k] = v
return d

def add_lora(self, lora):
self.loras.append(lora)

class StableDiffusionModelExample(object):

def __init__(self,
prompts=None,
neg_prompt=None,
sampler_name=None,
steps=None,
cfg_scale=None,
seed=None,
height=None,
width=None):
self.prompts = prompts
self.neg_prompt = neg_prompt
self.sampler_name = sampler_name
self.steps = steps
self.cfg_scale = cfg_scale
self.seed = seed
self.height = height
self.width = width


class OmniinferAPI(BaseAPI):
Expand All @@ -115,12 +125,18 @@ def load_from_config(cls):
o = OmniinferAPI()
if config.get('key') is not None:
o._token = config['key']
if config.get('models') is not None:
o._models = []
for model in config['models']:
if model.get('example'):
model['example'] = CheckpointExample(**model['example'])
o._models.append(Checkpoint(**model))
try:
if config.get('models') is not None:
o._models = []
for model in config['models']:
if model.get('example'):
model['example'] = StableDiffusionModelExample(
**model['example'])
o._models.append(StableDiffusionModel(**model))
except Exception as e:
print(
'[cloud-inference] failed to load models from config file, we will create a new one'
)
return o

@classmethod
Expand Down Expand Up @@ -454,7 +470,8 @@ def check_controlnet_arg(self, p):

controlnet_arg = {}
controlnet_arg['weight'] = c.weight
controlnet_arg['model'] = "control_v11f1e_sd15_tile" # TODO
controlnet_arg[
'model'] = "control_v11f1e_sd15_tile" # TODO
controlnet_arg['module'] = c.module

if c.control_mode == "Balanced":
Expand Down Expand Up @@ -514,50 +531,48 @@ def refresh_models(self):
}

print("[cloud-inference] refreshing models...")
results = []
sd_models = []

res = requests.get(url, headers=headers)
if res.status_code >= 400:
return []
# return [_["name"] for _ in res.json()["data"]["models"]]
m = {}
for item in res.json()["data"]["models"]:
if item['type'] == 'checkpoint':
ckpt = Checkpoint(
name=item["sd_name"],
rating=item.get("civitai_download_count", 0),
tags=item["civitai_tags"].split(",") if item.get("civitai_tags", None) is not None else []
)

if len(item.get(
'civitai_images',
[])) > 0 and item['civitai_images'][0]['meta'].get(
'prompt') is not None:
first_image = item['civitai_images'][0]['meta']
ckpt.example = CheckpointExample(
prompts=first_image['prompt'],
neg_prompt=first_image.get('negative_prompt', None),
width=first_image.get('width', None),
height=first_image.get('height', None),
sampler_name=first_image.get('sampler_name', None),
cfg_scale=first_image.get('cfg_scale', None),
seed=first_image.get('seed', None),
)
m[item['sd_name']] = ckpt
for item in res.json()["data"]["models"]:
model = StableDiffusionModel(kind=item["type"],
name=item["sd_name"])
model.rating = item.get("civitai_download_count", 0)
model.tags = item["civitai_tags"].split(",") if item.get(
"civitai_tags", None) is not None else []

if len(item.get('civitai_images',
[])) > 0 and item['civitai_images'][0]['meta'].get(
'prompt') is not None:
first_image = item['civitai_images'][0]['meta']
model.example = StableDiffusionModelExample(
prompts=first_image['prompt'],
neg_prompt=first_image.get('negative_prompt', None),
width=first_image.get('width', None),
height=first_image.get('height', None),
sampler_name=first_image.get('sampler_name', None),
cfg_scale=first_image.get('cfg_scale', None),
seed=first_image.get('seed', None))
if item['type'] == 'lora':
civitai_dependency_model_name = item.get(
'civitai_dependency_model_name', None)
if civitai_dependency_model_name is not None:
if m.get(civitai_dependency_model_name) is not None:
m[civitai_dependency_model_name].add_lora(
item['sd_name'])
model.dependency_model_name = civitai_dependency_model_name
sd_models.append(model)

m = {}
for model in sd_models:
m[model.name] = model

for _, ckpt in m.items():
results.append(ckpt)
for _, model in m.items():
if model.dependency_model_name is not None:
if m.get(model.dependency_model_name) is not None:
m[model.dependency_model_name].append_child(model.name)

self.__class__.update_models_to_config(results)
return results
self.__class__.update_models_to_config(sd_models)
return sd_models


def retrieve_images(img_urls) -> list[Image.Image]:
Expand Down
Loading

0 comments on commit 85938ac

Please sign in to comment.