Implement Automatic Model List Retrieval for SiliconCloud LLM API Node #135

Merged: 14 commits, Oct 16, 2024
Changes from 10 commits
19 changes: 19 additions & 0 deletions .prettierrc
@@ -0,0 +1,19 @@
{
"printWidth": 120,
"tabWidth": 4,
"useTabs": false,
"semi": true,
"singleQuote": false,
"quoteProps": "as-needed",
"jsxSingleQuote": false,
"trailingComma": "es5",
"bracketSpacing": true,
"bracketSameLine": false,
"arrowParens": "always",
"requirePragma": false,
"insertPragma": false,
"proseWrap": "preserve",
"htmlWhitespaceSensitivity": "css",
"endOfLine": "lf",
"embeddedLanguageFormatting": "auto"
}
72 changes: 70 additions & 2 deletions js/siliconcloud_llm_api.js
@@ -2,7 +2,7 @@ import { app } from "../../scripts/app.js";
import { ComfyWidgets } from "../../scripts/widgets.js";

app.registerExtension({
name: "bizyair.siliconcloud.llm.api",
name: "bizyair.siliconcloud.llm.api.populate",
async beforeRegisterNodeDef(nodeType, nodeData, app) {
if (nodeData.name === "BizyAirSiliconCloudLLMAPI") {
function populate(text) {
@@ -43,4 +43,72 @@ app.registerExtension({
};
}
},
})
});

app.registerExtension({
name: "bizyair.siliconcloud.llm.api.model_fetch",
async beforeRegisterNodeDef(nodeType, nodeData, app) {
if (nodeData.name === "BizyAirSiliconCloudLLMAPI") {
const originalNodeCreated = nodeType.prototype.onNodeCreated;
nodeType.prototype.onNodeCreated = async function () {
if (originalNodeCreated) {
originalNodeCreated.apply(this, arguments);
}

const modelWidget = this.widgets.find((w) => w.name === "model");

const fetchModels = async () => {
try {
const response = await fetch("/bizyair/get_silicon_cloud_models", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({}),
});

if (response.ok) {
const models = await response.json();
console.debug("Fetched models:", models);
return models;
} else {
console.error(`Failed to fetch models: ${response.status}`);
return [];
}
} catch (error) {
console.error(`Error fetching models`, error);
return [];
}
};

const updateModels = async () => {
const prevValue = modelWidget.value;
modelWidget.value = "";
modelWidget.options.values = [];

const models = await fetchModels();

modelWidget.options.values = models;
console.debug("Updated modelWidget.options.values:", modelWidget.options.values);

if (models.includes(prevValue)) {
modelWidget.value = prevValue; // stay on current.
} else if (models.length > 0) {
modelWidget.value = models[0]; // set first as default.
}

console.debug("Updated modelWidget.value:", modelWidget.value);
app.graph.setDirtyCanvas(true);
};

const dummy = async () => {
// calling async method will update the widgets with actual value from the browser and not the default from Node definition.
};

// Initial update
await dummy(); // this will cause the widgets to obtain the actual value from web page.
await updateModels();
};
}
},
});
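
For reference, the widget above populates its options by POSTing to the new /bizyair/get_silicon_cloud_models route added in llm.py below. A minimal sketch of exercising that route directly against a locally running ComfyUI instance (the 127.0.0.1:8188 address and the optional api_key override are assumptions):

```python
# Minimal sketch: query the new route from a running ComfyUI instance.
# 127.0.0.1:8188 assumes ComfyUI's default local address; adjust as needed.
import requests

resp = requests.post(
    "http://127.0.0.1:8188/bizyair/get_silicon_cloud_models",
    json={},  # or {"api_key": "..."} to override the server-side key
    timeout=10,
)
resp.raise_for_status()
print(resp.json())  # list of model ids, ending with "No LLM Enhancement"
```
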
48 changes: 29 additions & 19 deletions llm.py
@@ -1,6 +1,10 @@
import json
import os

import requests
from aiohttp import web
from server import PromptServer

from bizyair.common.env_var import BIZYAIR_SERVER_ADDRESS
from bizyair.image_utils import decode_data, encode_data

@@ -13,32 +17,38 @@
)


class SiliconCloudLLMAPI:
@PromptServer.instance.routes.post("/bizyair/get_silicon_cloud_models")
async def get_silicon_cloud_models_endpoint(request):
data = await request.json()
api_key = data.get("api_key", get_api_key())
url = "https://api.siliconflow.cn/v1/models"
headers = {"accept": "application/json", "authorization": f"Bearer {api_key}"}
try:
response = requests.get(url, headers=headers)
response.raise_for_status()
data = response.json()
models = [model["id"] for model in data["data"]]
models.append("No LLM Enhancement")
return web.json_response(models)
except requests.RequestException as e:
print(f"Error fetching models: {e}")
return web.json_response(["Error fetching models"], status=500)

display_name_to_id = {
"Yi1.5 9B": "01-ai/Yi-1.5-9B-Chat-16K",
"DeepSeekV2 Chat": "deepseek-ai/DeepSeek-V2-Chat",
"(Free)GLM4 9B Chat": "THUDM/glm-4-9b-chat",
"Qwen2 72B Instruct": "Qwen/Qwen2-72B-Instruct",
"(Free)Qwen2 7B Instruct": "Qwen/Qwen2-7B-Instruct",
"No LLM Enhancement": "Bypass",
}

class SiliconCloudLLMAPI:
def __init__(self):
pass

@classmethod
def INPUT_TYPES(s):
models = list(s.display_name_to_id.keys())
default_sysmtem_prompt = """你是一个 stable diffusion prompt 专家,为我生成适用于 Stable Diffusion 模型的prompt。
我给你相关的单词,你帮我扩写为适合 Stable Diffusion 文生图的 prompt。要求:
1. 英文输出
2. 除了 prompt 外,不要输出任何其它的信息
"""
default_system_prompt = """你是一个 stable diffusion prompt 专家,为我生成适用于 Stable Diffusion 模型的prompt。 我给你相关的单词,你帮我扩写为适合 Stable Diffusion 文生图的 prompt。要求: 1. 英文输出 2. 除了 prompt 外,不要输出任何其它的信息 """
return {
"required": {
"model": (models, {"default": "(Free)GLM4 9B Chat"}),
"model": ((), {}),
"system_prompt": (
"STRING",
{
"default": default_sysmtem_prompt,
"default": default_system_prompt,
"multiline": True,
"dynamicPrompts": True,
},
@@ -68,10 +78,10 @@ def INPUT_TYPES(s):
def get_llm_model_response(
self, model, system_prompt, user_prompt, max_tokens, temperature
):
if self.display_name_to_id[model] == "Bypass":
if model == "No LLM Enhancement":
return {"ui": {"text": (user_prompt,)}, "result": (user_prompt,)}
response = get_llm_response(
self.display_name_to_id[model],
model,
system_prompt,
user_prompt,
max_tokens,
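
The new route simply proxies SiliconFlow's model listing, so an API key can also be sanity-checked outside ComfyUI by issuing the same upstream request standalone. A minimal sketch mirroring the call in llm.py (the key value is a placeholder):

```python
# Standalone sketch of the upstream request made by the new route.
# Replace the placeholder key with a real SiliconFlow API key.
import requests

API_KEY = "sk-..."  # placeholder
resp = requests.get(
    "https://api.siliconflow.cn/v1/models",
    headers={"accept": "application/json", "authorization": f"Bearer {API_KEY}"},
    timeout=10,
)
resp.raise_for_status()
print([m["id"] for m in resp.json()["data"]])
```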