Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SharedLoraLoader node #154

Merged
merged 4 commits into from
Sep 29, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,3 +798,53 @@ def INPUT_TYPES(s):
# FUNCTION = "encode"

CATEGORY = "conditioning/inpaint"


class SharedLoraLoader(BizyAir_LoraLoader):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"share_id": ("STRING", {"default": "share_id"}),
"lora_name": ("STRING", {"default": "lora_name"}),
"model": (data_types.MODEL,),
"clip": (data_types.CLIP,),
"strength_model": (
"FLOAT",
{"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01},
),
"strength_clip": (
"FLOAT",
{"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01},
),
}
}

RETURN_TYPES = (data_types.MODEL, data_types.CLIP)
RETURN_NAMES = ("MODEL", "CLIP")
FUNCTION = "shared_load_lora"
CATEGORY = f"{PREFIX}/loaders"
NODE_DISPLAY_NAME = "Shared Lora Loader"

@classmethod
def VALIDATE_INPUTS(cls, share_id: str, lora_name: str):
if lora_name in folder_paths.filename_path_mapping.get("loras", {}):
return True

outs = folder_paths.get_share_filename_list("loras", share_id=share_id)
if lora_name not in outs:
raise ValueError(
f"Lora {lora_name} not found in share {share_id} with {outs}"
)
return True

def shared_load_lora(
self, model, clip, lora_name, strength_model, strength_clip, **kwargs
):
return super().load_lora(
model=model,
clip=clip,
lora_name=lora_name,
strength_model=strength_model,
strength_clip=strength_clip,
)
6 changes: 5 additions & 1 deletion src/bizy_server/modelhost.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,10 @@ async def submit_upload(request):
self.uploads[upload_id]["type"] = json_data["type"]
self.uploads[upload_id]["name"] = json_data["name"]
self.upload_queue.put(self.uploads[upload_id])

# enable refresh for lora
# TODO: enable refresh for other types
bizyair.path_utils.path_manager.enable_refresh_options("loras")
return OKResponse(None)

@prompt_server.routes.get(f"/{API_PREFIX}/models/files")
Expand Down Expand Up @@ -286,12 +290,12 @@ async def list_share_model_files(request):

if "ext_name" in request.rel_url.query:
payload["ext_name"] = request.rel_url.query["ext_name"]

model_files, err = await self.get_share_model_files(
shareId=shareId, payload=payload
)
if err is not None:
return ErrResponse(err)

return OKResponse(model_files)

@prompt_server.routes.delete(f"/{API_PREFIX}/models")
Expand Down
2 changes: 2 additions & 0 deletions src/bizyair/path_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from .path_manager import (
convert_prompt_label_path_to_real_path,
disable_refresh_options,
enable_refresh_options,
get_filename_list,
guess_config,
guess_url_from_node,
Expand Down
73 changes: 62 additions & 11 deletions src/bizyair/path_utils/path_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pprint
import re
import warnings
from dataclasses import dataclass
from typing import Any, Dict, List, Union

from ..common import fetch_models_by_type
Expand All @@ -24,6 +25,34 @@
filename_path_mapping: dict[str, dict[str, str]] = {}


@dataclass
class RefreshSettings:
loras: bool = True

def get(self, folder_name: str, default: bool = True):
return getattr(self, folder_name, default)

def set(self, folder_name: str, value: bool):
setattr(self, folder_name, value)


refresh_settings = RefreshSettings()


def enable_refresh_options(folder_names: Union[str, list[str]]):
if isinstance(folder_names, str):
folder_names = [folder_names]
for folder_name in folder_names:
refresh_settings.set(folder_name, True)


def disable_refresh_options(folder_names: Union[str, list[str]]):
if isinstance(folder_names, str):
folder_names = [folder_names]
for folder_name in folder_names:
refresh_settings.set(folder_name, False)


def _get_config_path():
src_bizyair_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
configs_path = os.path.join(src_bizyair_path, "configs")
Expand Down Expand Up @@ -88,26 +117,37 @@ def get_config_file_list(base_path=None) -> list:


def cached_filename_list(
folder_name: str, *, verbose=False, refresh=False
folder_name: str, *, share_id: str = None, verbose=False, refresh=False
) -> list[str]:
global filename_path_mapping
if refresh or folder_name not in filename_path_mapping:
model_types: Dict[str, str] = models_config["model_types"]
url = get_service_route(models_config["model_hub"]["find_model"])
if share_id:
url = f"{BIZYAIR_SERVER_ADDRESS}/{share_id}/models/files"
else:
url = get_service_route(models_config["model_hub"]["find_model"])
msg = fetch_models_by_type(
url=url, method="GET", model_type=model_types[folder_name]
)
if verbose:
pprint.pprint({"cached_filename_list": msg})

if not msg or "data" not in msg or msg["data"] is None:
try:
if not msg or "data" not in msg or msg["data"] is None:
return []

filename_path_mapping[folder_name] = {
x["label_path"]: x["real_path"]
for x in msg["data"]["files"]
if x["label_path"]
}
except Exception as e:
warnings.warn(f"Failed to get filename list: {e}")
return []

filename_path_mapping[folder_name] = {
x["label_path"]: x["real_path"]
for x in msg["data"]["files"]
if x["label_path"]
}
finally:
# TODO fix share_id vaild refresh settings
if share_id is None:
disable_refresh_options(folder_name)

return list(
filter_files_extensions(
Expand Down Expand Up @@ -139,11 +179,23 @@ def convert_prompt_label_path_to_real_path(prompt: dict[str, dict[str, any]]) ->
return new_prompt


def get_share_filename_list(folder_name, share_id, *, verbose=BIZYAIR_DEBUG):
assert share_id is not None and isinstance(share_id, str)
# TODO fix share_id vaild refresh settings
return cached_filename_list(
folder_name, share_id=share_id, verbose=verbose, refresh=True
)


def get_filename_list(folder_name, *, verbose=BIZYAIR_DEBUG):

global folder_names_and_paths
results = []
if folder_name in models_config["model_types"]:
results.extend(cached_filename_list(folder_name, verbose=verbose, refresh=True))
refresh = refresh_settings.get(folder_name, True)
ccssu marked this conversation as resolved.
Show resolved Hide resolved
results.extend(
cached_filename_list(folder_name, verbose=verbose, refresh=refresh)
)
if folder_name in folder_names_and_paths:
results.extend(folder_names_and_paths[folder_name])
if BIZYAIR_DEBUG:
Expand All @@ -153,7 +205,6 @@ def get_filename_list(folder_name, *, verbose=BIZYAIR_DEBUG):
results.extend(folder_paths.get_filename_list(folder_name))
except:
pass

return results


Expand Down
Loading