From ce80c75aa4bac53b2c24e1a66cfacbd6ce77cb83 Mon Sep 17 00:00:00 2001 From: AnyISalIn Date: Sun, 27 Aug 2023 17:10:10 +0800 Subject: [PATCH] add model browser support (#36) Signed-off-by: AnyISalIn --- extension/api.py | 184 +++++++++++++-------- javascript/modelBrowser.js | 164 +++++++++++++++++++ scripts/main_ui.py | 317 ++++++++++++++++++++++++++++++++++--- scripts/network_ui.py | 80 ++++++++++ style.css | 228 ++++++++++++++++++++++++++ 5 files changed, 881 insertions(+), 92 deletions(-) create mode 100644 javascript/modelBrowser.js create mode 100644 scripts/network_ui.py create mode 100644 style.css diff --git a/extension/api.py b/extension/api.py index f1d6164..dd94e63 100644 --- a/extension/api.py +++ b/extension/api.py @@ -8,6 +8,8 @@ import importlib from omniinfer_client import * +from dataclass_wizard import JSONWizard, DumpMeta +from dataclasses import dataclass, field from typing import Dict @@ -52,34 +54,37 @@ def upscale(self, *args, **kwargs): pass -class StableDiffusionModel(object): - def __init__(self, - kind, - name, - rating=0, - tags=None, - child=None, - examples=None, - user_tags=None): - self.kind = kind # checkpoint, lora - self.name = name - self.rating = rating - self.tags = tags - if self.tags is None: - self.tags = [] - - self.user_tags = user_tags - if self.user_tags is None: - self.user_tags = [] - - self.child = child - if self.child is None: - self.child = [] - - self.examples = examples - - def append_child(self, child): - self.child.append(child) +class JSONe(JSONWizard): + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + DumpMeta(key_transform='SNAKE').bind_to(cls) + + +@dataclass +class StableDiffusionModelExample(JSONe): + prompts: Optional[str] = None + neg_prompt: Optional[str] = None + sampler_name: Optional[str] = None + steps: Optional[int] = None + seed: Optional[int] = None + height: Optional[int] = None + width: Optional[int] = None + preview: Optional[str] = None + cfg_scale: Optional[float] = None + + +@dataclass +class StableDiffusionModel(JSONe): + kind: str + name: str + rating: int = 0 + tags: List[str] = None + child: Optional[List[str]] = field(default_factory=lambda: []) + examples: Optional[List[StableDiffusionModelExample]] = field(default_factory=lambda: []) + user_tags: Optional[List[str]] = field(default_factory=lambda: []) + preview_url: Optional[str] = None + search_terms: Optional[List[str]] = field(default_factory=lambda: []) + origin_url: Optional[str] = None @property def alias(self): @@ -92,39 +97,35 @@ def alias(self): if len(self.tags) > 0: n = "[{}] ".format(",".join(self.tags)) return n + os.path.splitext(self.name)[0] - - def to_json(self): - d = {} - for k, v in self.__dict__.items(): - if isinstance(v, StableDiffusionModelExample): - d[k] = v.__dict__ - else: - d[k] = v - return d - - -class StableDiffusionModelExample(object): - - def __init__(self, - prompts=None, - neg_prompt=None, - sampler_name=None, - steps=None, - cfg_scale=None, - seed=None, - height=None, - width=None, - preview=None, - ): - self.prompts = prompts - self.neg_prompt = neg_prompt - self.sampler_name = sampler_name - self.steps = steps - self.cfg_scale = cfg_scale - self.seed = seed - self.height = height - self.width = width - self.preview = preview + + def add_user_tag(self, tag): + if tag not in self.user_tags: + self.user_tags.append(tag) + + + +# class StableDiffusionModelExample(object): + +# def __init__(self, +# prompts=None, +# neg_prompt=None, +# sampler_name=None, +# steps=None, +# cfg_scale=None, +# seed=None, +# height=None, +# width=None, +# preview=None, +# ): +# self.prompts = prompts +# self.neg_prompt = neg_prompt +# self.sampler_name = sampler_name +# self.steps = steps +# self.cfg_scale = cfg_scale +# self.seed = seed +# self.height = height +# self.width = width +# self.preview = preview class OmniinferAPI(BaseAPI, UpscaleAPI): @@ -159,6 +160,14 @@ def load_from_config(cls): # if no key, we will set it to NONE o._api_key = 'NONE' o.update_client() + + if config.get('models') is not None: + try: + o._models = [StableDiffusionModel.from_dict(m) for m in config['models']] + except Exception as exp: + print('[cloud-inference] failed to load models from config file {}, we will create a new one'.format(exp)) + o._models = [] + return o @classmethod @@ -180,6 +189,25 @@ def update_key_to_config(cls, key): json.dumps(config, ensure_ascii=False, indent=2, default=vars).encode('utf-8')) + @classmethod + def update_models_to_config(cls, models): + config = {} + if os.path.exists(OMNIINFER_CONFIG): + with open(OMNIINFER_CONFIG, 'r') as f: + try: + config = json.load(f) + except: + print( + '[cloud-inference] failed to load config file, we will create a new one' + ) + pass + + config['models'] = models + with open(OMNIINFER_CONFIG, 'wb+') as f: + f.write( + json.dumps(config, ensure_ascii=False, indent=2, + default=vars).encode('utf-8')) + @classmethod def test_connection(cls, api_key: str): client = OmniClient(api_key) @@ -390,6 +418,11 @@ def get_models(type_): for item in models: model = StableDiffusionModel(kind=item.type.value, name=item.sd_name) + model.search_terms = [ + item.sd_name, + item.name, + str(item.civitai_version_id) + ] model.rating = item.civitai_download_count civitai_tags = item.civitai_tags.split(",") if item.civitai_tags is not None else [] @@ -399,9 +432,12 @@ def get_models(type_): if len(civitai_tags) > 0: model.tags.append(civitai_tags[0]) - if item.civitai_nsfw: + if item.civitai_nsfw or item.civitai_image_nsfw: model.tags.append("nsfw") + if item.civitai_image_url: + model.preview_url = item.civitai_image_url + model.examples = [] if item.civitai_images: for img in item.civitai_images: @@ -426,11 +462,25 @@ def get_models(type_): sd_models.extend(get_models(ModelType.CONTROLNET)) sd_models.extend(get_models(ModelType.VAE)) sd_models.extend(get_models(ModelType.UPSCALER)) + sd_models.extend(get_models(ModelType.TEXT_INVERSION)) # build lora and checkpoint relationship - self._models = sd_models - return sd_models + merged_models = {} + origin_models = {} + for model in self._models: + origin_models[model.name] = model + for model in sd_models: + if model.name in origin_models: + # save user tags + merged_models[model.name] = model + merged_models[model.name].user_tags = origin_models[model.name].user_tags + else: + merged_models[model.name] = model + + self._models = [v for k, v in merged_models.items()] + self.update_models_to_config(self._models) + return self._models _instance = None @@ -574,7 +624,7 @@ def prepare_mask( def image_dict_from_any(image) -> Optional[Dict[str, np.ndarray]]: if image is None: return None - + if isinstance(image, (tuple, list)): image = {'image': image[0], 'mask': image[1]} elif not isinstance(image, dict): @@ -595,10 +645,10 @@ def image_dict_from_any(image) -> Optional[Dict[str, np.ndarray]]: if image['image'] is None: image['mask'] = None return image - + if 'mask' not in image: image['mask'] = None - + if isinstance(image['mask'], str): if os.path.exists(image['mask']): image['mask'] = np.array(Image.open(image['mask'])).astype('uint8') diff --git a/javascript/modelBrowser.js b/javascript/modelBrowser.js new file mode 100644 index 0000000..b3e86a5 --- /dev/null +++ b/javascript/modelBrowser.js @@ -0,0 +1,164 @@ +var globalModelBrowserPopup = null; +var globalModelBrowserPopupInner = null; +var globalModelBrowserListeners = []; +function closeModelBrowserPopup() { + if (!globalModelBrowserPopup) return; + + globalModelBrowserPopup.style.display = "none"; +} + +function modelBrowserPopup(tab, contents) { + if (!globalModelBrowserPopup) { + globalModelBrowserPopup = document.createElement('div'); + globalModelBrowserPopup.onclick = closeModelBrowserPopup; + globalModelBrowserPopup.classList.add('global-model-browser-popup'); + + var close = document.createElement('div'); + close.classList.add('global-model-browser-popup-close'); + close.onclick = closeModelBrowserPopup; + close.title = "Close"; + globalModelBrowserPopup.appendChild(close); + + globalModelBrowserPopupInner = document.createElement('div'); + globalModelBrowserPopupInner.onclick = function (event) { + event.stopPropagation(); + return false; + }; + globalModelBrowserPopupInner.classList.add('global-model-browser-popup-inner'); + globalModelBrowserPopup.appendChild(globalModelBrowserPopupInner); + + gradioApp().querySelector('.main').appendChild(globalModelBrowserPopup); + } + + doThingsAfterPopup(tab) + globalModelBrowserPopupInner.innerHTML = ''; + globalModelBrowserPopupInner.appendChild(contents); + + globalModelBrowserPopup.style.display = "flex"; +} + + + +function toggleSelected(button) { + const filterButtons = document.querySelectorAll('.filter-btn'); + + filterButtons.forEach(btn => { + btn.classList.remove('selected'); + }); + button.classList.add('selected'); +} + +function filterImages(tab, kind, selectedTag) { + const imageItems = document.querySelectorAll(`.image-item[data-kind="${kind}"]`); + + imageItems.forEach(item => { + const itemTags = item.getAttribute('data-tags').split(' '); + const shouldDisplay = selectedTag === 'all' || itemTags.includes(selectedTag); + item.style.display = shouldDisplay ? 'block' : 'none'; + }); +} + +function doThingsAfterPopup(tab) { + addFilterButtons(tab, 'checkpoint'); + addFilterButtons(tab, 'lora'); + addFilterButtons(tab, 'embedding'); + + applyTextSearch(tab, 'checkpoint'); + applyTextSearch(tab, 'lora'); + applyTextSearch(tab, 'embedding'); + + // addNsfwToggle() + addImageClickListener(tab); + + + applyNsfwClass(tab); +} + +function addImageClickListener(tab) { + const imageItems = document.querySelectorAll('.image-item'); + + imageItems.forEach(item => { + const selectButton = item.querySelector('#select-button'); + // const favoriteButton = item.querySelector('#favorite-btn'); + const titleElement = item.querySelector('.title').getAttribute('data-alias'); + const browserTabName = item.parentElement.parentElement.parentElement.querySelector('.heading-text').textContent + selectButton.addEventListener('click', (event) => { + if (browserTabName == 'CHECKPOINT Browser') { + desiredCloudInferenceCheckpointName = titleElement; + gradioApp().getElementById(`${tab}_change_cloud_checkpoint`).click() + } else if (browserTabName == 'LORA Browser') { + desiredCloudInferenceLoraName = titleElement; + gradioApp().getElementById(`${tab}_change_cloud_lora`).click() + } else if (browserTabName == 'EMBEDDING Browser') { + desiredCloudInferenceEmbeddingName = titleElement; + gradioApp().getElementById(`${tab}_change_cloud_embedding`).click() + } + }); + + // favoriteButton.addEventListener('click', () => { + // desciredCloudInferenceFavoriteModelName = titleElement; + // gradioApp().getElementById(`${tab}_favorite`).click() + // }) + }) +} + +function addFilterButtons(tab, kind) { + const filterButtons = document.querySelectorAll(`.filter-btn[data-kind="${kind}"][data-tab="${tab}"]`); + // Filter images based on selected filter + filterButtons.forEach(button => { + button.addEventListener('click', () => { + const selectedTag = button.getAttribute('data-tag'); + filterImages(tab, kind, selectedTag); + + toggleSelected(button); // Add selected style to clicked button + }); + }); +} + + +function applyNsfwClass(tab) { + const imageItems = document.querySelectorAll('.image-item'); + + imageItems.forEach(item => { + const itemTags = item.getAttribute('data-tags').split(' '); + if (itemTags.includes('nsfw')) { + item.classList.add('nsfw'); + } else { + item.classList.remove('nsfw'); + } + }); +} + +function applyTextSearch(tab, kind) { + document.getElementById(`${tab}-${kind}-filter-search-input`).addEventListener(`input`, () => { + const searchText = document.getElementById(`${tab}-${kind}-filter-search-input`).value.toLowerCase(); + const selectedTag = getSelectedTag(kind); + + const imageItems = document.querySelectorAll('.image-item'); + document.querySelectorAll(`.image-item[data-kind="${kind}"]`) + imageItems.forEach(item => { + const itemTags = item.getAttribute('data-tags').split(' '); + const searchTerms = item.getAttribute('data-search-terms'); + const matchesSearch = searchTerms.toLowerCase().includes(searchText.toLowerCase()); + const matchesTag = selectedTag === 'all' || itemTags.includes(selectedTag); + console.log(item, matchesSearch, matchesTag) + item.style.display = matchesSearch && matchesTag ? 'block' : 'none'; + }); + }); +} + +// function doThingsAfterClosePopup() { +// for (kind of ['checkpoint', 'lora', 'embedding']) { +// searchInput.removeEventListener(`${kind}-search-input-listener`) +// } +// } + +function getSelectedTag(kind) { + const selectedButton = document.querySelector(`.filter-btn.selected[data-kind="${kind}"]`); + return selectedButton ? selectedButton.getAttribute('data-tag') : 'all'; +} + +function openInNewTab(url) { + var win = window.open(url, '_blank'); + win.focus(); +} \ No newline at end of file diff --git a/scripts/main_ui.py b/scripts/main_ui.py index 73dc713..9df30ea 100644 --- a/scripts/main_ui.py +++ b/scripts/main_ui.py @@ -2,20 +2,28 @@ import gradio as gr import os -from modules import script_callbacks, shared, paths_internal +from modules import script_callbacks, shared, paths_internal, ui_common from extension import api +from collections import Counter import random refresh_symbol = '\U0001f504' # 🔄 favorite_symbol = '\U0001f49e' # 💞 +model_browser_symbol = '\U0001f50d' # 🔍 class FormComponent: def get_expected_parent(self): return gr.components.Form +class FormButton(FormComponent, gr.Button): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def get_block_name(self): + return "button" class ToolButton(FormComponent, gr.Button): """Small button with single emoji as text, fits inside gradio forms""" @@ -57,6 +65,7 @@ def __init__(self): self.remote_models = None self.remote_models_aliases = {} self.remote_model_checkpoints = None + self.remote_model_embeddings = None self.remote_model_loras = None self.remote_model_controlnet = None self.remote_model_vaes = None @@ -83,7 +92,7 @@ def __init__(self): self.ext_controlnet_installed = False - def on_selected_model(self, name_index: int, selected_loras: list[str], suggest_prompts_enabled, prompt: str, neg_prompt: str): + def on_selected_model(self, name_index: int, selected_loras: list[str], selected_embedding: list[str], suggest_prompts_enabled, prompt: str, neg_prompt: str): selected: api.StableDiffusionModel = self.find_model_by_alias(name_index) selected_checkpoint = selected @@ -101,16 +110,18 @@ def on_selected_model(self, name_index: int, selected_loras: list[str], suggest_ if suggest_prompts_enabled and example.neg_prompt: neg_prompt = example.neg_prompt neg_prompt = neg_prompt.replace("\n", "") + if len(selected_embedding) > 0: + neg_prompt = self._update_embedding_in_neg_prompt(neg_prompt, selected_embedding) return gr.Dropdown.update( - choices=[_.alias for _ in self.remote_model_checkpoints], - value=selected_checkpoint.alias), gr.update(value=selected_loras), gr.update(value=prompt), gr.update(value=neg_prompt) + choices=[_.alias for _ in self.remote_model_checkpoints], value=selected_checkpoint.alias), gr.update(value=prompt), gr.update(value=neg_prompt) def update_models(self): for model in self.remote_models: self.remote_models_aliases[model.alias] = model _binding.remote_model_loras = _get_kind_from_remote_models(_binding.remote_models, "lora") + _binding.remote_model_embeddings = _get_kind_from_remote_models(_binding.remote_models, "textualinversion") _binding.remote_model_checkpoints = _get_kind_from_remote_models(_binding.remote_models, "checkpoint") _binding.remote_model_vaes = _get_kind_from_remote_models(_binding.remote_models, "vae") _binding.remote_model_controlnet = _get_kind_from_remote_models(_binding.remote_models, "controlnet") @@ -143,9 +154,37 @@ def _update_lora_in_prompt(prompt, _lora_names, weight=1): return ", ".join(prompt_split) + @staticmethod + def _update_embedding_in_neg_prompt(neg_prompt, _embedding_names): + embedding_names = [] + for embedding_name in _embedding_names: + name = _binding.find_model_by_alias(embedding_name).name.rsplit(".", 1)[0] # remove extension + embedding_names.append(name) + + neg_prompt = neg_prompt + add_embedding_prompts = [] + + neg_prompt_split = [_.strip() for _ in neg_prompt.split(',')] + + # add + for embedding_name in embedding_names: + if embedding_name not in neg_prompt: + add_embedding_prompts.append(embedding_name) + # delete + for prompt_item in neg_prompt_split: + if embedding_name not in embedding_names: + neg_prompt_split.remove(prompt_item) + + neg_prompt_split.extend(add_embedding_prompts) + + return ", ".join(neg_prompt_split) + def update_selected_lora(self, lora_names, prompt): return gr.update(value=self._update_lora_in_prompt(prompt, lora_names)) + def update_selected_embedding(self, embedding_names, neg_prompt): + return gr.update(value=self._update_embedding_in_neg_prompt(neg_prompt, embedding_names)) + def update_cloud_api(self, v): self.cloud_api = v @@ -159,6 +198,19 @@ def find_name_by_alias(self, choice): if model.alias == choice: return model.name + # def update_model_favorite(self, alias): + # model = self.find_model_by_alias(alias) + # if model is not None: + # if "favorite" in model.tags: + # model.tags.remove("favorite") + # else: + # model.tags.append("favorite") + # return gr.update(value=build_model_browser_html_for_checkpoint("txt2img", _binding.remote_model_checkpoints)), \ + # gr.update(value=build_model_browser_html_for_loras("txt2img", _binding.remote_model_loras)), \ + # gr.update(value=build_model_browser_html_for_embeddings("txt2img", _binding.remote_model_embeddings)), \ + + + def _get_kind_from_remote_models(models, kind): t = [] @@ -216,18 +268,21 @@ def ui(self, is_img2img): label="Service Provider", choices=["Omniinfer"], value="Omniinfer", - elem_id="{}_cloud_api_dropdown".format(tabname) + elem_id="{}_cloud_api_dropdown".format(tabname), + scale=1 ) cloud_inference_model_dropdown = gr.Dropdown( label="Checkpoint", - choices=[ - _.alias for _ in _binding.remote_model_checkpoints], + choices=[_.alias for _ in _binding.remote_model_checkpoints], value=lambda: _binding.default_remote_model, - elem_id="{}_cloud_inference_model_dropdown".format(tabname)) + elem_id="{}_cloud_inference_model_dropdown".format(tabname), scale=2) + - refresh_button = ToolButton( - value=refresh_symbol, elem_id="{}_cloud_inference_refersh_button".format(tabname)) + model_browser_button = FormButton(value="{} Browser".format(model_browser_symbol), elem_classes='model-browser-button', elem_id="{}_cloud_inference_browser_button".format(tabname), scale=0) + refresh_button = ToolButton(value=refresh_symbol, elem_id="{}_cloud_inference_refersh_button".format(tabname)) + + # model_browser_button = ToolButton(model_browser_symbol, elem_id="{}_cloud_inference_browser_button".format(tabname)) # favorite_button = ToolButton( # value=favorite_symbol, elem_id="{}_cloud_inference_favorite_button".format(tabname)) @@ -236,6 +291,10 @@ def ui(self, is_img2img): choices=[_.alias for _ in _binding.remote_model_loras], label="Lora", elem_id="{}_cloud_inference_lora_dropdown", multiselect=True, scale=4) + cloud_inference_embedding_dropdown = gr.Dropdown( + choices=[_.alias for _ in _binding.remote_model_embeddings], + label="Embedding", + elem_id="{}_cloud_inference_embedding_dropdown", multiselect=True, scale=4) cloud_inference_extra_checkbox = gr.Checkbox( label="Extra", @@ -244,6 +303,25 @@ def ui(self, is_img2img): scale=1 ) + # functionally + hide_button_change_checkpoint = gr.Button('Change Cloud checkpoint', elem_id='{}_change_cloud_checkpoint'.format(tabname), visible=False) + hide_button_change_lora = gr.Button('Change Cloud LORA', elem_id='{}_change_cloud_lora'.format(tabname), visible=False) + hide_button_change_embedding = gr.Button('Change Cloud Embedding', elem_id='{}_change_cloud_embedding'.format(tabname), visible=False) + # hide_button_favorite = gr.Button('Favorite', elem_id='{}_favorite'.format(tabname), visible=False) + + with gr.Box(elem_id='{}_model_browser'.format(tabname), elem_classes="popup-model-browser", visbile=False) as checkpoint_model_browser_dialog: + with gr.Tab(label="Checkpoint", elem_id='{}_model_browser_checkpoint_tab'.format(tabname)): + model_checkpoint_browser_dialog_html = gr.HTML(build_model_browser_html_for_checkpoint(tabname, _binding.remote_model_checkpoints)) + with gr.Tab(label="LORA", elem_id='{}_model_browser_lora_tab'.format(tabname)): + model_lora_browser_dialog_html = gr.HTML(build_model_browser_html_for_loras(tabname, _binding.remote_model_loras)) + with gr.Tab(label="Embedding", elem_id='{}_model_browser_embedding_tab'.format(tabname)): + model_embedding_browser_dialog_html = gr.HTML(build_model_browser_html_for_embeddings(tabname, _binding.remote_model_embeddings)) + + + checkpoint_model_browser_dialog.visible = False + model_browser_button.click(fn=lambda: gr.update(visible=True), inputs=[], outputs=[checkpoint_model_browser_dialog],).\ + then(fn=None, _js="function(){ modelBrowserPopup('" + tabname + "', gradioApp().getElementById('" + checkpoint_model_browser_dialog.elem_id + "')); }", show_progress=True) + with gr.Row(visible=False) as extra_row: cloud_inference_vae_dropdown = gr.Dropdown( choices=["Automatic", "None"] + [_.name for _ in _binding.remote_model_vaes], @@ -255,25 +333,47 @@ def ui(self, is_img2img): cloud_inference_extra_checkbox.change(lambda x: gr.update(visible=x), inputs=[ cloud_inference_extra_checkbox], outputs=[extra_row]) + # lora # define events of components. # auto fill prompt after select model - cloud_inference_model_dropdown.select( + hide_button_change_checkpoint.click( fn=_binding.on_selected_model, + _js="function(a, b, c, d, e, f){ var res = desiredCloudInferenceCheckpointName; desiredCloudInferenceCheckpointName = ''; return [res, b, c, d, e, f]; }", inputs=[ cloud_inference_model_dropdown, cloud_inference_lora_dropdown, + cloud_inference_embedding_dropdown, cloud_inference_suggest_prompts_checkbox, getattr(_binding, "{}_prompt".format(tabname)), getattr(_binding, "{}_neg_prompt".format(tabname)) - ], outputs=[ cloud_inference_model_dropdown, - cloud_inference_lora_dropdown, getattr(_binding, "{}_prompt".format(tabname)), getattr(_binding, "{}_neg_prompt".format(tabname)) - ]) - + ] + ) + # dummy_component = gr.Label(visible=False) + # hide_button_favorite.click( + # fn=_binding.update_model_favorite, + # _js='''function(){ name = desciredCloudInferenceFavoriteModelName; desciredCloudInferenceFavoriteModelName = ""; return [name]; }''', + # inputs=[dummy_component], + # outputs=[ + # model_checkpoint_browser_dialog_html, + # model_lora_browser_dialog_html, + # model_embedding_browser_dialog_html, + # ], + # ) + + hide_button_change_lora.click( + fn=lambda x, y: _binding.update_selected_lora(x, y), + _js="function(a, b){ a.includes(desiredCloudInferenceLoraName) || a.push(desiredCloudInferenceLoraName); desiredCloudInferenceLoraName = ''; return [a, b]; }", + inputs=[ + cloud_inference_lora_dropdown, + getattr(_binding, "{}_prompt".format(tabname)) + ], + outputs=getattr(_binding, "{}_prompt".format(tabname)), + ) # auto fill prompt after select lora cloud_inference_lora_dropdown.select( fn=lambda x, y: _binding.update_selected_lora(x, y), @@ -281,23 +381,72 @@ def ui(self, is_img2img): cloud_inference_lora_dropdown, getattr(_binding, "{}_prompt".format(tabname)) ], - outputs=getattr(_binding, "{}_prompt".format(tabname)), ) - def _model_refresh(): - api.get_instance().refresh_models() - _binding.remote_models = api.get_instance().list_models() - _binding.update_models() + hide_button_change_embedding.click( + fn=lambda x, y: _binding.update_selected_embedding(x, y), + _js="function(a, b){ a.includes(desiredCloudInferenceEmbeddingName) || a.push(desiredCloudInferenceEmbeddingName); desiredCloudInferenceEmbeddingName = ''; return [a, b]; }", + inputs=[ + cloud_inference_embedding_dropdown, + getattr(_binding, "{}_neg_prompt".format(tabname)) + ], + outputs=getattr(_binding, "{}_neg_prompt".format(tabname)), + ) + # embeddings + cloud_inference_embedding_dropdown.select( + fn=lambda x, y: _binding.update_selected_embedding(x, y), + inputs=[ + cloud_inference_embedding_dropdown, + getattr(_binding, "{}_neg_prompt".format(tabname)) + ], + outputs=[ + getattr(_binding, "{}_neg_prompt".format(tabname)), + ] + ) + + cloud_inference_model_dropdown.select( + fn=_binding.on_selected_model, + inputs=[ + cloud_inference_model_dropdown, + cloud_inference_lora_dropdown, + cloud_inference_embedding_dropdown, + cloud_inference_suggest_prompts_checkbox, + getattr(_binding, "{}_prompt".format(tabname)), + getattr(_binding, "{}_neg_prompt".format(tabname)) + ], + outputs=[ + cloud_inference_model_dropdown, + getattr(_binding, "{}_prompt".format(tabname)), + getattr(_binding, "{}_neg_prompt".format(tabname)) + ]) - return gr.update(choices=[_.alias for _ in _binding.remote_model_checkpoints]), gr.update(choices=[_.alias for _ in _binding.remote_model_loras]), gr.update(choices=["Automatic", "None"] + [_.name for _ in _binding.remote_model_vaes]) + def _model_refresh(tab): + def wrapper(): + api.get_instance().refresh_models() + _binding.remote_models = api.get_instance().list_models() + _binding.update_models() + + return gr.update(choices=[_.alias for _ in _binding.remote_model_checkpoints]), \ + gr.update(choices=[_.alias for _ in _binding.remote_model_loras]), \ + gr.update(choices=["Automatic", "None"] + [_.name for _ in _binding.remote_model_vaes]), \ + gr.update(choices=[_.alias for _ in _binding.remote_model_embeddings]), \ + gr.update(value=build_model_browser_html_for_checkpoint(tab, _binding.remote_model_checkpoints)), \ + gr.update(value=build_model_browser_html_for_loras(tab, _binding.remote_model_loras)), \ + gr.update(value=build_model_browser_html_for_embeddings(tab, _binding.remote_model_embeddings)) + return wrapper refresh_button.click( - fn=_model_refresh, + fn=_model_refresh(tabname), inputs=[], outputs=[cloud_inference_model_dropdown, cloud_inference_lora_dropdown, - cloud_inference_vae_dropdown + cloud_inference_embedding_dropdown, + cloud_inference_vae_dropdown, + + model_checkpoint_browser_dialog_html, + model_lora_browser_dialog_html, + model_embedding_browser_dialog_html, ]) return [cloud_inference_checkbox, cloud_inference_model_dropdown, cloud_inference_vae_dropdown] @@ -443,6 +592,124 @@ def on_after_component_callback(component, **_kwargs): _binding.initialized = True +def build_model_browser_html_for_checkpoint(tab, checkpoints): + column_html = "" + column_size = 5 + column_items = [[] for _ in range(column_size)] + tag_counter = Counter() + kind = "checkpoint" + for i, model in enumerate(checkpoints): + trimed_tags = [_.replace(" ", "_") for _ in model.tags] + tag_counter.update(trimed_tags) + if model.preview_url is None or not model.preview_url.startswith("http"): + model.preview_url = "https://via.placeholder.com/512x512.png?text=Preview+Not+Available" + model_html = f"""
+ +
+
{model.name.rsplit(".", 1)[0]}
+
+
+
+
+ +
+
""" + column_index = i % column_size + column_items[column_index].append(model_html) + + for i in range(column_size): + column_image_items_html = "" + for item in column_items[i]: + column_image_items_html += item + column_html += """
{}
""".format(column_image_items_html) + + tag_html = f"""
+ + """ + tag_html += """{}
""" + tag_html = tag_html.format("\n".join([f"""""" for _ in tag_counter.most_common()])) + + return f"""

{kind.upper()} Browser

{tag_html} + + """ + + +def build_model_browser_html_for_loras(tab, loras): + column_html = "" + column_size = 5 + column_items = [[] for _ in range(column_size)] + tag_counter = Counter() + kind = "lora" + for i, model in enumerate(loras): + trimed_tags = [_.replace(" ", "_") for _ in model.tags] + tag_counter.update(trimed_tags) + model_html = f"""
+ +
+
{model.name.rsplit(".", 1)[0]}
+
+
+
+
+ +
+
""" + column_index = i % column_size + column_items[column_index].append(model_html) + + for i in range(column_size): + column_image_items_html = "" + for item in column_items[i]: + column_image_items_html += item + column_html += """
{}
""".format(column_image_items_html) + + tag_html = f"""
+ + """ + tag_html += """{}
""" + tag_html = tag_html.format("\n".join([f"""""" for _ in tag_counter.most_common()])) + + return f"""

{kind.upper()} Browser

{tag_html}""" + + +def build_model_browser_html_for_embeddings(tab, embeddings): + column_html = "" + column_size = 5 + column_items = [[] for _ in range(column_size)] + tag_counter = Counter() + kind = "embedding" + for i, model in enumerate(embeddings): + trimed_tags = [_.replace(" ", "_") for _ in model.tags] + tag_counter.update(trimed_tags) + model_html = f"""
+ +
+
{model.name.rsplit(".", 1)[0]}
+
+
+
+
+ +
+
""" + column_index = i % column_size + column_items[column_index].append(model_html) + + for i in range(column_size): + column_image_items_html = "" + for item in column_items[i]: + column_image_items_html += item + column_html += """
{}
""".format(column_image_items_html) + + tag_html = f"""
+ + """ + tag_html += """{}
""" + tag_html = tag_html.format("\n".join([f"""""" for _ in tag_counter.most_common()])) + + return f"""

{kind.upper()} Browser

{tag_html}""" + + def sync_two_component(a, b, event_name): def mirror(a, b): if a != b: @@ -457,8 +724,8 @@ def mirror(a, b): if a != b: b = a return a, b - getattr(a, "select")(fn=mirror, inputs=[a, b], outputs=[a, b]) - getattr(b, "select")(fn=mirror, inputs=[b, a], outputs=[b, a]) + getattr(a, "change")(fn=mirror, inputs=[a, b], outputs=[a, b]) + getattr(b, "change")(fn=mirror, inputs=[b, a], outputs=[b, a]) def on_cloud_inference_checkbox_change_without_controlnet(txt2img_checkbox, diff --git a/scripts/network_ui.py b/scripts/network_ui.py new file mode 100644 index 0000000..1d4cf3e --- /dev/null +++ b/scripts/network_ui.py @@ -0,0 +1,80 @@ +from pathlib import Path + +from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks, script_callbacks, ui_extra_networks +from modules.images import read_info_from_image, save_image_with_geninfo +import gradio as gr +import json +import html +from fastapi.exceptions import HTTPException + +from modules.generation_parameters_copypaste import image_from_url_text +from modules.ui_components import ToolButton + + +class ExtraNetworksPage: + def __init__(self, title): + self.title = title + self.name = title.lower() + self.id_page = self.name.replace(" ", "_") + self.card_page = shared.html("extra-networks-card.html") + self.allow_negative_prompt = False + self.metadata = {} + self.items = {} + + def refresh(self): + pass + + def create_html(self, tabname): + pass + + def create_item(self, name, index=None, enable_filter=True): + return { + "name": "test", + "filename": "test", + "preview": "https://next-app-static.s3.amazonaws.com/images-prod/xG1nkqKTMzGDvpLrqFT7WA/7e6f18a0-e02a-4934-70e8-359c4c302f00/width=450/53255.jpeg", + "description": "1234", + "search_term": "1234", + "metadata": {}, + } + + def list_items(self): + return [{ + "name": "test", + "filename": "test", + "preview": "https://next-app-static.s3.amazonaws.com/images-prod/xG1nkqKTMzGDvpLrqFT7WA/7e6f18a0-e02a-4934-70e8-359c4c302f00/width=450/53255.jpeg", + "description": "1234", + "search_term": "1234", + # "onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"', + # "local_preview": f"{path}.{shared.opts.samples_format}", + "metadata": {}, + # "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)}, + }] + + def allowed_directories_for_previews(self): + return [] + + def create_html_for_item(self, item, tabname): + import ipdb; ipdb.set_trace() + return shared.html("abc") + + def get_sort_keys(self, path): + """ + List of default keys used for sorting in the UI. + """ + pth = Path(path) + stat = pth.stat() + return { + "date_created": int(stat.st_ctime or 0), + "date_modified": int(stat.st_mtime or 0), + "name": pth.name.lower(), + } + + def create_user_metadata_editor(self, ui, tabname): + return ui_extra_networks_user_metadata.UserMetadataEditor(ui, tabname, self) + + +def register_page(*args, **kwargs): + ui_extra_networks.register_page(ExtraNetworksPage("Cloud Models")) + + +script_callbacks.on_before_ui(register_page) diff --git a/style.css b/style.css new file mode 100644 index 0000000..66325a9 --- /dev/null +++ b/style.css @@ -0,0 +1,228 @@ +.search-container { + text-align: center; + margin: 20px 0; +} + +#search-input { + padding: 10px; + width: 80%; + border: 1px solid #ccc; + border-radius: 5px; + font-size: 16px; +} + +.title-container { + background: rgba(214, 214, 214, 0.6); + padding: 1px; + /* Adjucompact padding for a better look */ + border-radius: 5px; + position: absolute; + bottom: 0; + left: 0; + width: 100%; + display: flex; + align-items: center; + /* Center vertically */ + justify-content: space-between; + /* Spread items horizontally */ +} + +.title { + font-weight: bold; + font-size: 12px; + /* Adjust font size for better visibility */ + color: var(--background-fill-primary); + /* Adjust text color */ + margin: 0; + /* Reset margin for the title */ +} + +.buttons { + position: absolute; + top: 10px; + right: 10px; + display: flex; + gap: 5px; +} + +.btn { + background-color: var(--background-fill-primary); + color: #fff; + border: none; + padding: 5px 10px; + border-radius: 5px; + cursor: pointer; +} + + + +.heading-text { + margin-bottom: 2rem; + font-size: 2rem; +} + +.heading-text span { + font-weight: 100; +} + +/* Responsive image gallery rules begin*/ + +.image-gallery { + /* Mobile first */ + /* max-height: 500px; */ + overflow-y: auto; + display: flex; + flex-direction: column; + gap: 10px; +} + +.image-gallery .column { + display: flex; + flex-direction: column; + gap: 10px; +} + +.image-item img { + width: 100%; + border-radius: 5px; + height: 100%; + object-fit: cover; +} + +.image-item.nsfw img { + filter: blur(10px); +} + +@media only screen and (min-width: 768px) { + .image-gallery { + flex-direction: row; + } +} + +/* overlay styles */ + +.image-item { + position: relative; + cursor: pointer; + min-height: 200px; +} + +.overlay { + position: absolute; + width: 100%; + height: 100%; + background: rgba(57, 57, 57, 0.502); + top: 0; + left: 0; + transform: scale(0); + transition: all 0.2s 0.1s ease-in-out; + color: #fff; + /* center overlay content */ + display: flex; + align-items: center; + justify-content: center; +} + +/* hover */ +.image-item:hover .overlay { + transform: scale(1); +} + +.filter-buttons { + gap: 10px; + margin-top: 10px; +} + +.filter-buttons button { + background-color: var(--background-fill-primary); + color: #fff; + border: none; + padding: 5px 10px; + border-radius: 5px; + cursor: pointer; +} + +/* Style for selected filter button */ +.filter-buttons button.selected { + background-color: var(--background-fill-primary); + /* Change color for selected button */ +} + +/* Style for selected filter button */ +.filter-buttons button.selected { + background-color: var(--block-label-text-color); + color: var(--background-fill-primary); + border-top-left-radius: 5px; + border-top-right-radius: 5px; + border: 1px solid #ccc; + border-bottom: none; +} + +#select-button { + background-color: transparent; + color: white; +} + + +.search-bar { + margin-top: 10px; + padding: 10px; +} + +.filter-search-input { + color: black; +} + + +.global-model-browser-popup { + display: flex; + position: fixed; + z-index: 1001; + left: 0; + top: 0; + width: 100%; + height: 100%; + overflow: auto; + /* background-color: rgba(20, 20, 20, 0.95); */ +} + +.global-model-browser-popup * { + box-sizing: border-box; +} + +.global-model-browser-popup-close:before { + content: "×"; +} + +.global-model-browser-popup-close { + position: fixed; + right: 0.25em; + top: 0; + cursor: pointer; + color: var(--block-label-text-color); + font-size: 32pt; +} + +.global-model-browser-popup-inner { + display: inline-block; + margin: auto; + padding: 2em; +} + +div.block.gradio-box.popup-model-browser { + position: absolute; + left: 50%; + width: 40%; + top: 40%; + height: 60%; + background: var(--body-background-fill); + /* padding: 2em !important; */ +} + + +.gradio-button.model-browser-button { + height: 2.4em; + align-self: end; + line-height: 1em; + border-radius: 0.5em; +} \ No newline at end of file