diff --git a/extension/api.py b/extension/api.py index dd94e63..697ed00 100644 --- a/extension/api.py +++ b/extension/api.py @@ -97,13 +97,12 @@ def alias(self): if len(self.tags) > 0: n = "[{}] ".format(",".join(self.tags)) return n + os.path.splitext(self.name)[0] - + def add_user_tag(self, tag): if tag not in self.user_tags: self.user_tags.append(tag) - # class StableDiffusionModelExample(object): # def __init__(self, @@ -266,11 +265,10 @@ def img2img( else: save_kwargs = {} - buffered = io.BytesIO() - i.save(buffered, format=live_previews_image_format, **save_kwargs) - base64_image = base64.b64encode( - buffered.getvalue()).decode('ascii') - images_base64.append(base64_image) + with io.BytesIO() as buffered: + i.save(buffered, format=live_previews_image_format, **save_kwargs) + base64_image = base64.b64encode(buffered.getvalue()).decode('ascii') + images_base64.append(base64_image) def _req(p: processing.StableDiffusionProcessingImg2Img, controlnet_units): req = Img2ImgRequest( @@ -302,8 +300,16 @@ def _req(p: processing.StableDiffusionProcessingImg2Img, controlnet_units): if 'sd_vae' in p._cloud_inference_settings: req.sd_vae = p._cloud_inference_settings['sd_vae'] + if hasattr(p, 'refiner_checkpoint') and p.refiner_checkpoint is not None and p.refiner_checkpoint != "None": + req.sd_refiner = Refiner( + checkpoint=p.refiner_checkpoint, + switch_at=p.refiner_switch_at, + ) + if len(controlnet_units) > 0: req.controlnet_units = controlnet_units + if opts.data.get("control_net_no_detectmap", False): + req.controlnet_no_detectmap = True res = self._client.sync_img2img(req, download_images=False, callback=self._update_state) return res.data.imgs @@ -352,6 +358,14 @@ def _req(p: processing.StableDiffusionProcessingTxt2Img, controlnet_units): if len(controlnet_units) > 0: req.controlnet_units = controlnet_units + if opts.data.get("control_net_no_detectmap", False): + req.controlnet_no_detectmap = True + + if hasattr(p, 'refiner_checkpoint') and p.refiner_checkpoint is not None and p.refiner_checkpoint != "None": + req.sd_refiner = Refiner( + checkpoint=p.refiner_checkpoint, + switch_at=p.refiner_switch_at, + ) res = self._client.sync_txt2img(req, download_images=False, callback=self._update_state) if res.data.status != ProgressResponseStatusCode.SUCCESSFUL: @@ -477,7 +491,7 @@ def get_models(type_): merged_models[model.name].user_tags = origin_models[model.name].user_tags else: merged_models[model.name] = model - + self._models = [v for k, v in merged_models.items()] self.update_models_to_config(self._models) return self._models @@ -668,7 +682,8 @@ def _download(img_url): while attempts > 0: try: response = requests.get(img_url, timeout=2) - return Image.open(io.BytesIO(response.content)) + with io.BytesIO(response.content) as fp: + return Image.open(fp).copy() except Exception: print("[cloud-inference] failed to download image, retrying...") attempts -= 1 diff --git a/extension/version.py b/extension/version.py index 283b03a..a23ef3f 100644 --- a/extension/version.py +++ b/extension/version.py @@ -1 +1 @@ -__version__ = "0.1.7" \ No newline at end of file +__version__ = "0.1.8" \ No newline at end of file diff --git a/install.py b/install.py index 78acc03..485ef78 100644 --- a/install.py +++ b/install.py @@ -1,3 +1,3 @@ import launch -launch.run_pip("install omniinfer_client==0.3.3", "requirements for sd-webui-cloud-inference") +launch.run_pip("install omniinfer_client==0.3.5", "requirements for sd-webui-cloud-inference") diff --git a/scripts/main_ui.py b/scripts/main_ui.py index b3ffcec..4f2e4a4 100644 --- a/scripts/main_ui.py +++ b/scripts/main_ui.py @@ -18,6 +18,7 @@ class FormComponent: def get_expected_parent(self): return gr.components.Form + class FormButton(FormComponent, gr.Button): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -25,6 +26,7 @@ def __init__(self, *args, **kwargs): def get_block_name(self): return "button" + class ToolButton(FormComponent, gr.Button): """Small button with single emoji as text, fits inside gradio forms""" @@ -71,6 +73,12 @@ def __init__(self): self.remote_model_vaes = None self.remote_model_upscalers = None + # refiner + self.txt2img_checkpoint = None + self.img2img_checkpoint = None + self.txt2img_checkpoint_refresh = None + self.img2img_checkpoint_refresh = None + # third component self.txt2img_controlnet_model_dropdown_units = [] self.img2img_controlnet_model_dropdown_units = [] @@ -86,10 +94,13 @@ def __init__(self): self.extras_upscaler_2_original = None self.txt2img_controlnet_model_dropdown_original_units = [] self.img2img_controlnet_model_dropdown_original_units = [] + self.txt2img_checkpoint_original = None + self.img2img_checkpoint_original = None self.default_remote_model = None self.initialized = False + self.bultin_refiner_supported = False self.ext_controlnet_installed = False def on_selected_model(self, name_index: int, selected_loras: list[str], selected_embedding: list[str], suggest_prompts_enabled, prompt: str, neg_prompt: str): @@ -208,8 +219,6 @@ def find_name_by_alias(self, choice): # return gr.update(value=build_model_browser_html_for_checkpoint("txt2img", _binding.remote_model_checkpoints)), \ # gr.update(value=build_model_browser_html_for_loras("txt2img", _binding.remote_model_loras)), \ # gr.update(value=build_model_browser_html_for_embeddings("txt2img", _binding.remote_model_embeddings)), \ - - def _get_kind_from_remote_models(models, kind): @@ -278,8 +287,8 @@ def ui(self, is_img2img): value=lambda: _binding.default_remote_model, elem_id="{}_cloud_inference_model_dropdown".format(tabname), scale=2) - - model_browser_button = FormButton(value="{} Browser".format(model_browser_symbol), elem_classes='model-browser-button', elem_id="{}_cloud_inference_browser_button".format(tabname), scale=0) + model_browser_button = FormButton(value="{} Browser".format(model_browser_symbol), elem_classes='model-browser-button', + elem_id="{}_cloud_inference_browser_button".format(tabname), scale=0) refresh_button = ToolButton(value=refresh_symbol, elem_id="{}_cloud_inference_refersh_button".format(tabname)) # model_browser_button = ToolButton(model_browser_symbol, elem_id="{}_cloud_inference_browser_button".format(tabname)) @@ -317,7 +326,6 @@ def ui(self, is_img2img): with gr.Tab(label="Embedding", elem_id='{}_model_browser_embedding_tab'.format(tabname)): model_embedding_browser_dialog_html = gr.HTML(build_model_browser_html_for_embeddings(tabname, _binding.remote_model_embeddings)) - checkpoint_model_browser_dialog.visible = False model_browser_button.click(fn=lambda: gr.update(visible=True), inputs=[], outputs=[checkpoint_model_browser_dialog],).\ then(fn=None, _js="function(){ modelBrowserPopup('" + tabname + "', gradioApp().getElementById('" + checkpoint_model_browser_dialog.elem_id + "')); }", show_progress=True) @@ -428,12 +436,12 @@ def wrapper(): _binding.update_models() return gr.update(choices=[_.alias for _ in _binding.remote_model_checkpoints]), \ - gr.update(choices=[_.alias for _ in _binding.remote_model_loras]), \ - gr.update(choices=["Automatic", "None"] + [_.name for _ in _binding.remote_model_vaes]), \ - gr.update(choices=[_.alias for _ in _binding.remote_model_embeddings]), \ - gr.update(value=build_model_browser_html_for_checkpoint(tab, _binding.remote_model_checkpoints)), \ - gr.update(value=build_model_browser_html_for_loras(tab, _binding.remote_model_loras)), \ - gr.update(value=build_model_browser_html_for_embeddings(tab, _binding.remote_model_embeddings)) + gr.update(choices=[_.alias for _ in _binding.remote_model_loras]), \ + gr.update(choices=["Automatic", "None"] + [_.name for _ in _binding.remote_model_vaes]), \ + gr.update(choices=[_.alias for _ in _binding.remote_model_embeddings]), \ + gr.update(value=build_model_browser_html_for_checkpoint(tab, _binding.remote_model_checkpoints)), \ + gr.update(value=build_model_browser_html_for_loras(tab, _binding.remote_model_loras)), \ + gr.update(value=build_model_browser_html_for_embeddings(tab, _binding.remote_model_embeddings)) return wrapper refresh_button.click( @@ -462,6 +470,12 @@ def wrapper(): if os.path.isdir(os.path.join(paths_internal.extensions_dir, "sd-webui-controlnet")) and 'sd-webui-controlnet' not in shared.opts.data.get('disabled_extensions', []): _binding.ext_controlnet_installed = True + try: + import modules.processing_scripts.refiner + _binding.bultin_refiner_supported = True + except: + pass + from scripts.hijack import _hijack_manager _hijack_manager._binding = _binding _hijack_manager.hijack_onload() @@ -544,6 +558,33 @@ def on_after_component_callback(component, **_kwargs): component.choices = [_.alias for _ in _binding.remote_model_upscalers] component.value = component.choices[0] + # txt2img refiner + if type(component) is gr.Dropdown and getattr(component, 'elem_id', None) == 'txt2img_checkpoint': + _binding.txt2img_checkpoint = component + _binding.txt2img_checkpoint_original = component.get_config() + + if _binding.remote_inference_enabled: + component.choices = ["None"] + [_.name for _ in _binding.remote_model_checkpoints if 'refiner' in _.name] # TODO + component.value = component.choices[0] + if gr.Dropdown and getattr(component, 'elem_id', None) == 'txt2img_checkpoint_refresh': + _binding.txt2img_checkpoint_refresh = component + if _binding.remote_inference_enabled: + component.visible = False + + # img2img refiner + if type(component) is gr.Dropdown and getattr(component, 'elem_id', None) == 'img2img_checkpoint': + _binding.img2img_checkpoint = component + _binding.img2img_checkpoint_original = component.get_config() + + if _binding.remote_inference_enabled: + component.choices = ["None"] + [_.name for _ in _binding.remote_model_checkpoints if 'refiner' in _.name] # TODO + component.value = component.choices[0] + + if gr.Dropdown and getattr(component, 'elem_id', None) == 'img2img_checkpoint_refresh': + _binding.img2img_checkpoint_refresh = component + if _binding.remote_inference_enabled: + component.visible = False + if _binding.txt2img_cloud_inference_checkbox and \ _binding.img2img_cloud_inference_checkbox and \ _binding.txt2img_cloud_inference_model_dropdown and \ @@ -562,32 +603,15 @@ def on_after_component_callback(component, **_kwargs): if expect_unit_amount != len(_binding.txt2img_controlnet_model_dropdown_units): return + if _binding.bultin_refiner_supported: + if _binding.txt2img_checkpoint is None or _binding.img2img_checkpoint is None: + return + sync_cloud_model(_binding.txt2img_cloud_inference_model_dropdown, _binding.img2img_cloud_inference_model_dropdown) - sync_two_component(_binding.txt2img_cloud_inference_suggest_prompts_checkbox, _binding.img2img_cloud_inference_suggest_prompts_checkbox, 'change') - - if not _binding.ext_controlnet_installed: - on_cloud_inference_checkbox_change_without_controlnet(_binding.txt2img_cloud_inference_checkbox, - _binding.img2img_cloud_inference_checkbox, - _binding.txt2img_generate, - _binding.img2img_generate, - _binding.extras_upscaler_1, - _binding.extras_upscaler_2, - _binding.txt2img_hr_upscaler - ) - else: - on_cloud_inference_checkbox_change(_binding.txt2img_cloud_inference_checkbox, - _binding.img2img_cloud_inference_checkbox, - _binding.txt2img_generate, - _binding.img2img_generate, - _binding.txt2img_controlnet_model_dropdown_units, - _binding.img2img_controlnet_model_dropdown_units, - _binding.extras_upscaler_1, - _binding.extras_upscaler_2, - _binding.txt2img_hr_upscaler - ) + on_cloud_inference_checkbox_change(_binding) _binding.initialized = True @@ -728,14 +752,7 @@ def mirror(a, b): getattr(b, "change")(fn=mirror, inputs=[b, a], outputs=[b, a]) -def on_cloud_inference_checkbox_change_without_controlnet(txt2img_checkbox, - img2img_checkbox, - txt2img_generate_button, - img2img_generate_button, - extras_upscaler_1, - extras_upscaler_2, - txt2img_hr_upscaler - ): +def on_cloud_inference_checkbox_change(binding: DataBinding): def mirror(source, target): enabled = source @@ -744,145 +761,104 @@ def mirror(source, target): button_text = "Generate" if enabled: - _binding.remote_inference_enabled = True + binding.remote_inference_enabled = True button_text = "Generate (cloud)" else: - _binding.remote_inference_enabled = False + binding.remote_inference_enabled = False - upscale_models_with_none = ["None"] + [_.alias for _ in _binding.remote_model_upscalers] - upscale_models = [_.alias for _ in _binding.remote_model_upscalers] + controlnet_models = ["None"] + [_.name for _ in binding.remote_model_controlnet] + upscale_models_with_none = ["None"] + [_.alias for _ in binding.remote_model_upscalers] + upscale_models = [_.alias for _ in binding.remote_model_upscalers] + refiner_models = ["None"] + [_.name for _ in binding.remote_model_checkpoints if 'refiner' in _.name] # TODO - if not enabled: - allow_update_fields = ['value', 'choices'] - extras_upscaler_1_config = {k: v for k, v in _binding.extras_upscaler_1_original.items() if k in allow_update_fields} - extras_upscaler_2_config = {k: v for k, v in _binding.extras_upscaler_2_original.items() if k in allow_update_fields} - txt2img_hr_upscaler_config = {k: v for k, v in _binding.txt2img_hr_upscaler_original.items() if k in allow_update_fields} - - return source, \ - target, \ - button_text, \ - button_text, \ - gr.update(**extras_upscaler_1_config), \ - gr.update(**extras_upscaler_2_config), \ - gr.update(**txt2img_hr_upscaler_config) - - return source, \ - target, \ - button_text,\ - button_text,\ - gr.update(value=upscale_models[0], choices=upscale_models), \ - gr.update(value=upscale_models_with_none[0], choices=upscale_models_with_none), \ - gr.update(value=upscale_models[0], choices=upscale_models) \ - - txt2img_checkbox.change(fn=mirror, - inputs=[txt2img_checkbox, - img2img_checkbox], - outputs=[ - txt2img_checkbox, - img2img_checkbox, - txt2img_generate_button, - img2img_generate_button, - extras_upscaler_1, - extras_upscaler_2, - txt2img_hr_upscaler - ]) - img2img_checkbox.change(fn=mirror, - inputs=[img2img_checkbox, - txt2img_checkbox - ], - outputs=[ - img2img_checkbox, - txt2img_checkbox, - txt2img_generate_button, - img2img_generate_button, - extras_upscaler_1, - extras_upscaler_2, - txt2img_hr_upscaler - ]) - - -def on_cloud_inference_checkbox_change(txt2img_checkbox, - img2img_checkbox, - txt2img_generate_button, - img2img_generate_button, - txt2img_controlnet_model_dropdown_units, - img2img_controlnet_model_dropdown_units, - extras_upscaler_1, - extras_upscaler_2, - txt2img_hr_upscaler - ): - def mirror(source, target): - enabled = source + update_components = ( + source, + target, + button_text, + button_text, + ) - if source != target: - target = source + def back_to_original(origin_config): + allow_update_fields = ['value', 'choices'] + return {k: v for k, v in origin_config.items() if k in allow_update_fields} - button_text = "Generate" - if enabled: - _binding.remote_inference_enabled = True - button_text = "Generate (cloud)" - else: - _binding.remote_inference_enabled = False + if not enabled: + update_components += ( + gr.update(**back_to_original(binding.extras_upscaler_1_original)), + gr.update(**back_to_original(binding.extras_upscaler_2_original)), + gr.update(**back_to_original(binding.txt2img_hr_upscaler_original)) + ) + if binding.ext_controlnet_installed: + update_components += ( + *[gr.update(**back_to_original(_)) for _ in binding.txt2img_controlnet_model_dropdown_original_units], + *[gr.update(**back_to_original(_)) for _ in binding.img2img_controlnet_model_dropdown_original_units], + ) + if binding.bultin_refiner_supported: + update_components += ( + gr.update(**back_to_original(binding.txt2img_checkpoint_original)), + gr.update(**back_to_original(binding.img2img_checkpoint_original)), + gr.update(visible=True), + gr.update(visible=True), + ) - controlnet_models = ["None"] + [_.name for _ in _binding.remote_model_controlnet] - upscale_models_with_none = ["None"] + [_.alias for _ in _binding.remote_model_upscalers] - upscale_models = [_.alias for _ in _binding.remote_model_upscalers] + return update_components + + update_components += ( + gr.update(value=upscale_models[0], choices=upscale_models), + gr.update(value=upscale_models_with_none[0], choices=upscale_models_with_none), + gr.update(value=upscale_models[0], choices=upscale_models), + ) + if binding.ext_controlnet_installed: + update_components += ( + *[gr.update(value=controlnet_models[0], choices=controlnet_models) for _ in binding.txt2img_controlnet_model_dropdown_units], + *[gr.update(value=controlnet_models[0], choices=controlnet_models) for _ in binding.img2img_controlnet_model_dropdown_units], + ) + if binding.bultin_refiner_supported: + update_components += ( + gr.update(value=refiner_models[0], choices=refiner_models), + gr.update(value=refiner_models[0], choices=refiner_models), + gr.update(visible=False), + gr.update(visible=False), + ) - if not enabled: - allow_update_fields = ['value', 'choices'] - extras_upscaler_1_config = {k: v for k, v in _binding.extras_upscaler_1_original.items() if k in allow_update_fields} - extras_upscaler_2_config = {k: v for k, v in _binding.extras_upscaler_2_original.items() if k in allow_update_fields} - txt2img_hr_upscaler_config = {k: v for k, v in _binding.txt2img_hr_upscaler_original.items() if k in allow_update_fields} - - return source, \ - target, \ - button_text, \ - button_text, \ - *[gr.update(**{k: v for k, v in _.items() if k in allow_update_fields}) for _ in _binding.txt2img_controlnet_model_dropdown_original_units], \ - *[gr.update(**{k: v for k, v in _.items() if k in allow_update_fields}) for _ in _binding.img2img_controlnet_model_dropdown_original_units], \ - gr.update(**extras_upscaler_1_config), \ - gr.update(**extras_upscaler_2_config), \ - gr.update(**txt2img_hr_upscaler_config) - - return source, \ - target, \ - button_text,\ - button_text,\ - *[gr.update(value=controlnet_models[0], choices=controlnet_models) for _ in txt2img_controlnet_model_dropdown_units], \ - *[gr.update(value=controlnet_models[0], choices=controlnet_models) for _ in img2img_controlnet_model_dropdown_units], \ - gr.update(value=upscale_models[0], choices=upscale_models), \ - gr.update(value=upscale_models_with_none[0], choices=upscale_models_with_none), \ - gr.update(value=upscale_models[0], choices=upscale_models) \ - - txt2img_checkbox.change(fn=mirror, - inputs=[txt2img_checkbox, - img2img_checkbox], - outputs=[ - txt2img_checkbox, - img2img_checkbox, - txt2img_generate_button, - img2img_generate_button, - *txt2img_controlnet_model_dropdown_units, - *img2img_controlnet_model_dropdown_units, - extras_upscaler_1, - extras_upscaler_2, - txt2img_hr_upscaler - ]) - img2img_checkbox.change(fn=mirror, - inputs=[img2img_checkbox, - txt2img_checkbox - ], - outputs=[ - img2img_checkbox, - txt2img_checkbox, - txt2img_generate_button, - img2img_generate_button, - *txt2img_controlnet_model_dropdown_units, - *img2img_controlnet_model_dropdown_units, - extras_upscaler_1, - extras_upscaler_2, - txt2img_hr_upscaler - ]) + return update_components + + expect_update_components = ( + _binding.txt2img_generate, + _binding.img2img_generate, + _binding.extras_upscaler_1, + _binding.extras_upscaler_2, + _binding.txt2img_hr_upscaler + ) + if _binding.ext_controlnet_installed: + expect_update_components += ( + *_binding.txt2img_controlnet_model_dropdown_units, + *_binding.img2img_controlnet_model_dropdown_units, + ) + if _binding.bultin_refiner_supported: + expect_update_components += ( + _binding.txt2img_checkpoint, + _binding.img2img_checkpoint, + _binding.txt2img_checkpoint_refresh, + _binding.img2img_checkpoint_refresh, + ) + + _binding.txt2img_cloud_inference_checkbox.change(fn=mirror, + inputs=[_binding.txt2img_cloud_inference_checkbox, + _binding.img2img_cloud_inference_checkbox, + ], + outputs=[ + _binding.img2img_cloud_inference_checkbox, + _binding.txt2img_cloud_inference_checkbox, + *expect_update_components]) + _binding.img2img_cloud_inference_checkbox.change(fn=mirror, + inputs=[_binding.img2img_cloud_inference_checkbox, + _binding.txt2img_cloud_inference_checkbox], + outputs=[ + _binding.img2img_cloud_inference_checkbox, + _binding.txt2img_cloud_inference_checkbox, + *expect_update_components + ]) def on_ui_settings(): diff --git a/scripts/network_ui.py b/scripts/network_ui.py deleted file mode 100644 index 1d4cf3e..0000000 --- a/scripts/network_ui.py +++ /dev/null @@ -1,80 +0,0 @@ -from pathlib import Path - -from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks, script_callbacks, ui_extra_networks -from modules.images import read_info_from_image, save_image_with_geninfo -import gradio as gr -import json -import html -from fastapi.exceptions import HTTPException - -from modules.generation_parameters_copypaste import image_from_url_text -from modules.ui_components import ToolButton - - -class ExtraNetworksPage: - def __init__(self, title): - self.title = title - self.name = title.lower() - self.id_page = self.name.replace(" ", "_") - self.card_page = shared.html("extra-networks-card.html") - self.allow_negative_prompt = False - self.metadata = {} - self.items = {} - - def refresh(self): - pass - - def create_html(self, tabname): - pass - - def create_item(self, name, index=None, enable_filter=True): - return { - "name": "test", - "filename": "test", - "preview": "https://next-app-static.s3.amazonaws.com/images-prod/xG1nkqKTMzGDvpLrqFT7WA/7e6f18a0-e02a-4934-70e8-359c4c302f00/width=450/53255.jpeg", - "description": "1234", - "search_term": "1234", - "metadata": {}, - } - - def list_items(self): - return [{ - "name": "test", - "filename": "test", - "preview": "https://next-app-static.s3.amazonaws.com/images-prod/xG1nkqKTMzGDvpLrqFT7WA/7e6f18a0-e02a-4934-70e8-359c4c302f00/width=450/53255.jpeg", - "description": "1234", - "search_term": "1234", - # "onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"', - # "local_preview": f"{path}.{shared.opts.samples_format}", - "metadata": {}, - # "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)}, - }] - - def allowed_directories_for_previews(self): - return [] - - def create_html_for_item(self, item, tabname): - import ipdb; ipdb.set_trace() - return shared.html("abc") - - def get_sort_keys(self, path): - """ - List of default keys used for sorting in the UI. - """ - pth = Path(path) - stat = pth.stat() - return { - "date_created": int(stat.st_ctime or 0), - "date_modified": int(stat.st_mtime or 0), - "name": pth.name.lower(), - } - - def create_user_metadata_editor(self, ui, tabname): - return ui_extra_networks_user_metadata.UserMetadataEditor(ui, tabname, self) - - -def register_page(*args, **kwargs): - ui_extra_networks.register_page(ExtraNetworksPage("Cloud Models")) - - -script_callbacks.on_before_ui(register_page)