Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SharedLoraLoader node #154

Merged
merged 4 commits into from
Sep 29, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions js/share_lora_loader.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import { api } from "../../../scripts/api.js";
import { app } from "../../scripts/app.js";
app.registerExtension({
name: "bizyair.siliconcloud.share.lora.loader",
async beforeRegisterNodeDef(nodeType, nodeData, app) {
if (nodeData.name === "BizyAir_SharedLoraLoader") {
async function onTextChange(share_id, canvas, comfynode) {
console.log("share_id:", share_id);
const response = await api.fetchApi(`/bizyair/modelhost/${share_id}/models/files?type=bizyair/lora`, {
method: "GET",
headers: {
"Content-Type": "application/json",
},
});

const { data: loras_list } = await response.json();
// console.log("loras_list:", loras_list);
ccssu marked this conversation as resolved.
Show resolved Hide resolved
const lora_name_widget = comfynode.widgets.find(widget => widget.name === "lora_name");
if (loras_list.length > 0) {
lora_name_widget.value = loras_list[0];
lora_name_widget.options.values = loras_list;
} else {
console.log("No loras found in the response");
lora_name_widget.value = "";
lora_name_widget.options.values = [];
}
}

function setWigetCallback(){
const shareid_widget = this.widgets.find(widget => widget.name === "share_id");
if (shareid_widget) {
shareid_widget.callback = onTextChange;
} else {
console.log("share_id widget not found");
}
}
const onNodeCreated = nodeType.prototype.onNodeCreated
nodeType.prototype.onNodeCreated = function () {
onNodeCreated?.apply(this, arguments);
setWigetCallback.call(this, arguments);
};
}
},
})
50 changes: 50 additions & 0 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,3 +798,53 @@ def INPUT_TYPES(s):
# FUNCTION = "encode"

CATEGORY = "conditioning/inpaint"


class SharedLoraLoader(BizyAir_LoraLoader):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"share_id": ("STRING", {"default": "share_id"}),
"lora_name": ([],),
"model": (data_types.MODEL,),
"clip": (data_types.CLIP,),
"strength_model": (
"FLOAT",
{"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01},
),
"strength_clip": (
"FLOAT",
{"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01},
),
}
}

RETURN_TYPES = (data_types.MODEL, data_types.CLIP)
RETURN_NAMES = ("MODEL", "CLIP")
FUNCTION = "shared_load_lora"
CATEGORY = f"{PREFIX}/loaders"
NODE_DISPLAY_NAME = "Shared Lora Loader"

@classmethod
def VALIDATE_INPUTS(cls, share_id: str, lora_name: str):
if lora_name in folder_paths.filename_path_mapping.get("loras", {}):
return True

outs = folder_paths.get_share_filename_list("loras", share_id=share_id)
if lora_name not in outs:
raise ValueError(
f"Lora {lora_name} not found in share {share_id} with {outs}"
)
return True

def shared_load_lora(
self, model, clip, lora_name, strength_model, strength_clip, **kwargs
):
return super().load_lora(
model=model,
clip=clip,
lora_name=lora_name,
strength_model=strength_model,
strength_clip=strength_clip,
)
59 changes: 29 additions & 30 deletions src/bizy_server/modelhost.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,10 @@ async def submit_upload(request):
self.uploads[upload_id]["type"] = json_data["type"]
self.uploads[upload_id]["name"] = json_data["name"]
self.upload_queue.put(self.uploads[upload_id])

# enable refresh for lora
# TODO: enable refresh for other types
bizyair.path_utils.path_manager.enable_refresh_options("loras")
return OKResponse(None)

@prompt_server.routes.get(f"/{API_PREFIX}/models/files")
Expand Down Expand Up @@ -270,6 +274,7 @@ async def list_model_files(request):
@prompt_server.routes.get(f"/{API_PREFIX}" + "/{shareId}/models/files")
async def list_share_model_files(request):
shareId = request.match_info["shareId"]

if not self.is_string_valid(shareId):
return ErrResponse(INVALID_SHARE_ID)

Expand All @@ -286,12 +291,12 @@ async def list_share_model_files(request):

if "ext_name" in request.rel_url.query:
payload["ext_name"] = request.rel_url.query["ext_name"]

model_files, err = await self.get_share_model_files(
shareId=shareId, payload=payload
)
if err is not None:
return ErrResponse(err)

return OKResponse(model_files)

@prompt_server.routes.delete(f"/{API_PREFIX}/models")
Expand Down Expand Up @@ -524,41 +529,35 @@ async def get_model_files(self, payload) -> (dict, ErrorNo):
return result, None

async def get_share_model_files(self, shareId, payload) -> (dict, ErrorNo):
headers, err = self.auth_header()
if err is not None:
return None, err

server_url = f"{BIZYAIR_SERVER_ADDRESS}/{shareId}/models/files"
try:
resp = self.do_get(server_url, params=payload, headers=headers)
ret = json.loads(resp)
if ret["code"] != CODE_OK:
if ret["code"] == CODE_NO_MODEL_FOUND:
return [], None
else:
return None, ErrorNo(500, ret["code"], None, ret["message"])

if not ret["data"]:
return [], None
except Exception as e:
print(f"fail to list share model files: {str(e)}")
return None, LIST_SHARE_MODEL_FILE_ERR
def callback(ret: dict):
if ret["code"] != CODE_OK:
if ret["code"] == CODE_NO_MODEL_FOUND:
return [], None
else:
return [], ErrorNo(500, ret["code"], None, ret["message"])

files = ret["data"]["files"]
result = []
if len(files) > 0:
tree = defaultdict(lambda: {"name": "", "list": []})
if not ret or "data" not in ret or ret["data"] is None:
return [], None

for item in files:
parts = item["label_path"].split("/")
model_name = parts[0]
if model_name not in tree:
tree[model_name] = {"name": model_name, "list": [item]}
else:
tree[model_name]["list"].append(item)
result = list(tree.values())
outputs = [
x["label_path"] for x in ret["data"]["files"] if x["label_path"]
]
outputs = bizyair.path_utils.filter_files_extensions(
outputs,
extensions=bizyair.path_utils.path_manager.supported_pt_extensions,
)
return outputs, None

return result, None
ret = await bizyair.common.client.async_send_request(
method="GET", url=server_url, params=payload, callback=callback
)
return ret[0], ret[1]
except Exception as e:
print(f"fail to list share model files: {str(e)}")
return [], LIST_SHARE_MODEL_FILE_ERR

async def get_models(self, payload) -> (dict, ErrorNo):
headers, err = self.auth_header()
Expand Down
45 changes: 45 additions & 0 deletions src/bizyair/common/client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import asyncio
import json
import pprint
import urllib.error
import urllib.request
import warnings

import aiohttp

__all__ = ["send_request"]

from dataclasses import dataclass, field
Expand Down Expand Up @@ -134,6 +137,48 @@ def send_request(
return json.loads(response_data)


async def async_send_request(
method: str = "POST",
url: str = None,
data: bytes = None,
verbose=False,
callback: callable = process_response_data,
**kwargs,
) -> dict:
headers = kwargs.pop("headers") if "headers" in kwargs else _headers()
try:
async with aiohttp.ClientSession() as session:
async with session.request(
method, url, data=data, headers=headers, **kwargs
) as response:
response_data = await response.text()
if response.status != 200:
error_message = f"HTTP Status {response.status}"
if verbose:
print(f"Error encountered: {error_message}")
if response.status == 401:
raise PermissionError(
"Key is invalid, please refer to https://cloud.siliconflow.cn to get the API key.\n"
"If you have the key, please click the 'BizyAir Key' button at the bottom right to set the key."
)
else:
raise ConnectionError(
f"Failed to connect to the server: {error_message}.\n"
+ "Please check your API key and ensure the server is reachable.\n"
+ "Also, verify your network settings and disable any proxies if necessary.\n"
+ "After checking, please restart the ComfyUI service."
)
if callback:
return callback(json.loads(response_data))
return json.loads(response_data)
except aiohttp.ClientError as e:
print(f"Error fetching data: {e}")
return {}
except Exception as e:
print(f"Error fetching data: {str(e)}")
return {}


def fetch_models_by_type(
url: str, model_type: str, *, method="GET", verbose=False
) -> dict:
Expand Down
3 changes: 3 additions & 0 deletions src/bizyair/path_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from .path_manager import (
convert_prompt_label_path_to_real_path,
disable_refresh_options,
enable_refresh_options,
get_filename_list,
guess_config,
guess_url_from_node,
)
from .utils import filter_files_extensions
73 changes: 62 additions & 11 deletions src/bizyair/path_utils/path_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pprint
import re
import warnings
from dataclasses import dataclass
from typing import Any, Dict, List, Union

from ..common import fetch_models_by_type
Expand All @@ -24,6 +25,34 @@
filename_path_mapping: dict[str, dict[str, str]] = {}


@dataclass
class RefreshSettings:
loras: bool = True

def get(self, folder_name: str, default: bool = True):
return getattr(self, folder_name, default)

def set(self, folder_name: str, value: bool):
setattr(self, folder_name, value)


refresh_settings = RefreshSettings()


def enable_refresh_options(folder_names: Union[str, list[str]]):
if isinstance(folder_names, str):
folder_names = [folder_names]
for folder_name in folder_names:
refresh_settings.set(folder_name, True)


def disable_refresh_options(folder_names: Union[str, list[str]]):
if isinstance(folder_names, str):
folder_names = [folder_names]
for folder_name in folder_names:
refresh_settings.set(folder_name, False)


def _get_config_path():
src_bizyair_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
configs_path = os.path.join(src_bizyair_path, "configs")
Expand Down Expand Up @@ -88,26 +117,37 @@ def get_config_file_list(base_path=None) -> list:


def cached_filename_list(
folder_name: str, *, verbose=False, refresh=False
folder_name: str, *, share_id: str = None, verbose=False, refresh=False
) -> list[str]:
global filename_path_mapping
if refresh or folder_name not in filename_path_mapping:
model_types: Dict[str, str] = models_config["model_types"]
url = get_service_route(models_config["model_hub"]["find_model"])
if share_id:
url = f"{BIZYAIR_SERVER_ADDRESS}/{share_id}/models/files"
else:
url = get_service_route(models_config["model_hub"]["find_model"])
msg = fetch_models_by_type(
url=url, method="GET", model_type=model_types[folder_name]
)
if verbose:
pprint.pprint({"cached_filename_list": msg})

if not msg or "data" not in msg or msg["data"] is None:
try:
if not msg or "data" not in msg or msg["data"] is None:
return []

filename_path_mapping[folder_name] = {
x["label_path"]: x["real_path"]
for x in msg["data"]["files"]
if x["label_path"]
}
except Exception as e:
warnings.warn(f"Failed to get filename list: {e}")
return []

filename_path_mapping[folder_name] = {
x["label_path"]: x["real_path"]
for x in msg["data"]["files"]
if x["label_path"]
}
finally:
# TODO fix share_id vaild refresh settings
if share_id is None:
disable_refresh_options(folder_name)

return list(
filter_files_extensions(
Expand Down Expand Up @@ -139,11 +179,23 @@ def convert_prompt_label_path_to_real_path(prompt: dict[str, dict[str, any]]) ->
return new_prompt


def get_share_filename_list(folder_name, share_id, *, verbose=BIZYAIR_DEBUG):
assert share_id is not None and isinstance(share_id, str)
# TODO fix share_id vaild refresh settings
return cached_filename_list(
folder_name, share_id=share_id, verbose=verbose, refresh=True
)


def get_filename_list(folder_name, *, verbose=BIZYAIR_DEBUG):

global folder_names_and_paths
results = []
if folder_name in models_config["model_types"]:
results.extend(cached_filename_list(folder_name, verbose=verbose, refresh=True))
refresh = refresh_settings.get(folder_name, True)
ccssu marked this conversation as resolved.
Show resolved Hide resolved
results.extend(
cached_filename_list(folder_name, verbose=verbose, refresh=refresh)
)
if folder_name in folder_names_and_paths:
results.extend(folder_names_and_paths[folder_name])
if BIZYAIR_DEBUG:
Expand All @@ -153,7 +205,6 @@ def get_filename_list(folder_name, *, verbose=BIZYAIR_DEBUG):
results.extend(folder_paths.get_filename_list(folder_name))
except:
pass

return results


Expand Down
Loading