Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flux Slider Compatibility with ComfyUI #122

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
venv
outputs
datasets
__pycache__
inputs
217 changes: 217 additions & 0 deletions analysis_scripts/convert_pt_to_st.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
import torch
import safetensors.torch
from pathlib import Path
import numpy as np
import json

# Global dictionary to store name mappings.
# Maps an original flux_slider-style layer name -> its converted flux_ostris-style
# name; populated lazily by convert_layer_name() and dumped to JSON by
# save_name_mappings().
layer_name_mappings = {}

def analyze_state_dict(state_dict, file_name):
    """Print a human-readable summary of *state_dict*.

    Dumps every entry (name, shape, dtype), the total entry count, and how
    many LoRA up/down weight tensors are present.

    Args:
        state_dict: mapping of layer name -> tensor (values must expose
            ``.shape`` and ``.dtype``).
        file_name: label used in the report header.
    """
    print(f"\n=== Analysis for {file_name} ===")

    # Section 1: every layer with its shape and dtype.
    print("\n1. Layer Structure:")
    print("-------------------")
    for name, tensor in state_dict.items():
        print(name, tensor.shape, tensor.dtype)

    print(f"\nTotal layers: {len(state_dict)}")

    # Section 2: count the LoRA up/down projection weights specifically.
    print("\n2. LoRA Weights Analysis:")
    print("------------------------")
    ups = {k: v for k, v in state_dict.items() if 'lora_up' in k}
    downs = {k: v for k, v in state_dict.items() if 'lora_down' in k}

    print(f"\nFound {len(ups)} LoRA up layers and {len(downs)} LoRA down layers")

    # NOTE: a per-pair statistics dump (min/max/mean/std of up, down and the
    # effective up @ down product) used to live here but was disabled;
    # reinstate it if detailed pair-level analysis is needed again.

def convert_layer_name(old_name):
    """Convert a layer name from flux_slider format to flux_ostris format.

    Handles the double-stream ("transformer_blocks") and single-stream
    ("single_transformer_blocks") attention layers:

      * ``lora_up`` / ``lora_down`` are renamed to the PEFT-style
        ``lora_B`` / ``lora_A``.
      * The underscore-delimited ``lora_unet_...`` prefix is rewritten to
        dot notation rooted at ``transformer.``.

    Names matching neither prefix are returned unchanged. Every result is
    cached in the module-level ``layer_name_mappings`` dict, which
    save_name_mappings() later dumps to JSON.

    Args:
        old_name: original layer name, e.g.
            ``lora_unet_transformer_blocks_0_attn_to_q.lora_up.weight``.

    Returns:
        The converted name, e.g.
        ``transformer.transformer_blocks.0.attn.to_q.lora_B.weight``.
    """
    # If we've seen this name before, return the cached conversion.
    if old_name in layer_name_mappings:
        return layer_name_mappings[old_name]

    new_name = old_name

    # Handle double-stream transformer blocks.
    if old_name.startswith('lora_unet_transformer_blocks_'):
        block_num = old_name.split('_')[4]  # block index

        # Convert lora weights: up -> B, down -> A.
        if 'lora_up' in old_name:
            new_name = old_name.replace('lora_up', 'lora_B')
        elif 'lora_down' in old_name:
            new_name = old_name.replace('lora_down', 'lora_A')

        # Replace prefix and convert to dot notation.
        new_name = new_name.replace(
            f'lora_unet_transformer_blocks_{block_num}_attn_to_',
            f'transformer.transformer_blocks.{block_num}.attn.to_'
        ).replace(
            f'lora_unet_transformer_blocks_{block_num}_attn_add_',
            f'transformer.transformer_blocks.{block_num}.attn.add_'
        )

    # Handle single-stream transformer blocks.
    elif old_name.startswith('lora_unet_single_transformer_blocks_'):
        block_num = old_name.split('_')[5]  # block index

        # Convert lora weights: up -> B, down -> A.
        if 'lora_up' in old_name:
            new_name = old_name.replace('lora_up', 'lora_B')
        elif 'lora_down' in old_name:
            new_name = old_name.replace('lora_down', 'lora_A')

        # (A leftover debug print(new_name) was removed here; the parallel
        # branch above had the same print commented out.)

        # Replace prefix and convert to dot notation.
        new_name = new_name.replace(
            f'lora_unet_single_transformer_blocks_{block_num}_attn_to_',
            f'transformer.single_transformer_blocks.{block_num}.attn.to_'
        ).replace(
            f'lora_unet_single_transformer_blocks_{block_num}_attn_add_',
            f'transformer.single_transformer_blocks.{block_num}.attn.add_'
        )

    # Cache and return the mapping.
    layer_name_mappings[old_name] = new_name
    return new_name

def save_name_mappings(output_path):
    """Dump the accumulated layer name mappings to *output_path* as JSON."""
    with open(output_path, 'w') as fp:
        fp.write(json.dumps(layer_name_mappings, indent=2))
    print(f"\nSaved layer name mappings to {output_path}")

def convert_pt_to_safetensors(pt_path, output_path=None, analyze=True):
    """Convert a ``.pt`` LoRA checkpoint to ``.safetensors``.

    Every key is rewritten via convert_layer_name() before saving.

    Args:
        pt_path: path to the input ``.pt`` file.
        output_path: destination path; defaults to *pt_path* with a
            ``.safetensors`` suffix.
        analyze: currently unused (the state-dict analysis step is disabled).
    """
    # map_location makes checkpoints saved on GPU loadable on CPU-only hosts.
    # NOTE(review): torch.load unpickles arbitrary objects — only run this on
    # checkpoints from a trusted source.
    state_dict = torch.load(pt_path, map_location="cpu")

    # Re-key the state dict into the flux_ostris naming scheme.
    new_state_dict = {convert_layer_name(key): value
                      for key, value in state_dict.items()}

    # If output path is not specified, use the same name with a
    # .safetensors extension.
    if output_path is None:
        output_path = str(Path(pt_path).with_suffix('.safetensors'))

    # Print the first 5 keys as a quick sanity check, then save.
    print(list(new_state_dict.keys())[:5])
    safetensors.torch.save_file(new_state_dict, output_path)
    print(f"\nConverted {pt_path} to {output_path}")

def analyze_safetensors_file(safetensors_path):
    """Load a .safetensors file and print its layer summary."""
    analyze_state_dict(
        safetensors.torch.load_file(safetensors_path),
        Path(safetensors_path).name,
    )

if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--input_path", type=str, required=True, help="Path to .pt or .safetensors file or directory")
parser.add_argument("--output_path", type=str, help="Output path (optional)")
parser.add_argument("--analyze_only", action="store_true", help="Only analyze without converting")
args = parser.parse_args()

input_path = Path(args.input_path)

if input_path.is_file():
if args.analyze_only:
if input_path.suffix == '.pt':
state_dict = torch.load(str(input_path))
elif input_path.suffix == '.safetensors':
state_dict = safetensors.torch.load_file(str(input_path))
else:
raise ValueError("Input file must be either .pt or .safetensors")
analyze_state_dict(state_dict, input_path.name)
else:
if input_path.suffix == '.pt':
convert_pt_to_safetensors(str(input_path), args.output_path)
else:
print("Input file is already in safetensors format")
elif input_path.is_dir():
for file in input_path.glob("*.{pt,safetensors}"):
if args.analyze_only:
if file.suffix == '.pt':
state_dict = torch.load(str(file))
else:
state_dict = safetensors.torch.load_file(str(file))
analyze_state_dict(state_dict, file.name)
else:
if file.suffix == '.pt':
convert_pt_to_safetensors(str(file))
else:
print(f"Skipping {file} as it's already in safetensors format")

if not args.analyze_only:
save_name_mappings('outputs/layer_name_mappings.json')

# python analysis_scripts/convert_pt_to_st.py --input_path flux-sliders/outputs/person-obese-mod/slider_0.pt --output_path outputs/person-obse-mode.safetensors
# python analysis_scripts/convert_pt_to_st.py --input_path outputs/person-obse-mode.safetensors --analyze_only > outputs/person-obse-mode-layers.txt
69 changes: 69 additions & 0 deletions analysis_scripts/heirarchy_breakdown.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import re
from collections import defaultdict

def analyze_lora_layers(filepath):
    """Summarize LoRA parameter counts from a layers-info text file.

    Expects lines of the form produced by the state-dict dumper, e.g.::

        lora_unet_..._to_q.lora_up.weight torch.Size([3072, 16]) torch.float32

    Lines without ``torch.Size`` are skipped. Counts are grouped by the
    ``lora_*`` block prefix parsed out of each layer name.

    Args:
        filepath: path to the layers-info text file.

    Returns:
        A multi-line report string with per-block counts and a grand total.
    """
    # Parameter counts for each main block and sub-block.
    param_counts = defaultdict(lambda: defaultdict(int))

    # Parse "<main_block>_<block_num>_<block_type>" out of a layer name.
    layer_pattern = r'(lora_\w+)_(\d+)?_?(\w+)?'

    # BUG FIX: shapes were previously ignored (``shape = ""`` left every
    # layer counted as 1 parameter, so the report counted layers, not
    # parameters). Parse the torch.Size([...]) dimensions with a regex
    # instead of the old eval() approach.
    shape_pattern = re.compile(r'torch\.Size\(\[([^\]]*)\]\)')

    with open(filepath, 'r') as f:
        for line in f:
            if 'torch.Size' not in line:
                continue

            # First whitespace-delimited token is the layer name.
            layer_name = line.strip().split(' ')[0]

            shape_match = shape_pattern.search(line)
            if not shape_match:
                continue
            dims = [int(d) for d in shape_match.group(1).split(',') if d.strip()]

            # Parameter count = product of the shape dimensions
            # (1 for a scalar / empty shape).
            params = 1
            for dim in dims:
                params *= dim

            # Parse layer name to get the block hierarchy.
            base_name = layer_name.split('.')[0]
            match = re.match(layer_pattern, base_name)
            if match:
                main_block = match.group(1)
                block_num = match.group(2)

                # Update parameter counts, keyed by sub-block when present.
                if block_num:
                    param_counts[main_block][block_num] += params
                else:
                    param_counts[main_block]["total"] += params

    # Generate detailed report.
    report = []
    total_params = 0

    for main_block, sub_blocks in param_counts.items():
        block_total = sum(sub_blocks.values())
        total_params += block_total

        report.append(f"{main_block} - {block_total}")

        # Add sub-block details.
        for sub_block, params in sub_blocks.items():
            if sub_block != "total":
                report.append(f"  {main_block}_{sub_block} - {params}")

    report.append(f"\nTotal Parameters: {total_params}")

    return "\n".join(report)

# Usage:
#   python analysis_scripts/heirarchy_breakdown.py
if __name__ == "__main__":
    # Guarded so importing this module doesn't trigger an analysis run.
    filepath = "outputs/smiling_xl_layers_info.txt"
    detailed_breakdown = analyze_lora_layers(filepath)
    print(detailed_breakdown)
43 changes: 43 additions & 0 deletions flux-sliders/civit-download.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#!/bin/bash

# Download a model file from CivitAI using the API key stored in ~/.env
# (expects a CIVIT_API_KEY=... line).

# Check if correct number of arguments provided
if [ "$#" -ne 2 ]; then
    echo "Usage: $0 <model_url> <output_filename>"
    echo "Example: $0 'https://civitai.com/api/download/models/802821?type=Model&format=SafeTensor' 'my_model.safetensor'"
    exit 1
fi

# Load environment variables from .env file.
# 'set -a' auto-exports every variable the sourced file assigns; this is
# more robust than 'export $(cat ... | xargs)', which breaks on quoted
# values, comments, and values containing spaces.
if [ -f ~/.env ]; then
    set -a
    . ~/.env
    set +a
fi

# Check if API key is available
if [ -z "$CIVIT_API_KEY" ]; then
    echo "Error: CIVIT_API_KEY not found in .env file"
    exit 1
fi

# Get arguments and convert URL to API format
INPUT_URL="$1"
# Expand ~ to full home directory path and make it absolute
OUTPUT_FILE=$(realpath -m "${2/#\~/$HOME}")

# Extract the model ID from the URL
MODEL_ID=$(echo "$INPUT_URL" | grep -o 'models/[0-9]*' | cut -d'/' -f2)
MODEL_URL="https://civitai.com/api/download/models/${MODEL_ID}"

# Create output directory if it doesn't exist
OUTPUT_DIR=$(dirname "$OUTPUT_FILE")
mkdir -p "$OUTPUT_DIR"

# Download the file.
# --fail makes curl exit non-zero on HTTP 4xx/5xx; without it a "Not
# Found" error page would be saved as the model and the script would
# still report success.
echo "Downloading model to $OUTPUT_FILE..."
if curl --fail -L -H "Authorization: Bearer $CIVIT_API_KEY" "$MODEL_URL" --output "$OUTPUT_FILE"; then
    echo "Download complete!"
else
    echo "Download failed!"
    exit 1
fi
Loading