diff --git a/modules/config.py b/modules/config.py
index 58107806c..39710921f 100644
--- a/modules/config.py
+++ b/modules/config.py
@@ -135,33 +135,123 @@ def get_dir_or_set_default(key, default_value):
 path_outputs = get_dir_or_set_default('path_outputs', '../outputs/')
 
 
-def get_config_item_or_set_default(key, default_value, validator, disable_empty_as_none=False):
+def get_config_item_or_set_default(key, default_value, validator, disable_empty_as_none=False, corrector=None):
     global config_dict, visited_keys
+    debug_mode = False
+    if debug_mode:
+        print(f"Checking key: {key}")
+
     if key not in visited_keys:
         visited_keys.append(key)
 
-    if key not in config_dict:
-        config_dict[key] = default_value
-        return default_value
-
     v = config_dict.get(key, None)
+    if debug_mode:
+        print(f"Value for key {key}: {v}")
+
     if not disable_empty_as_none:
         if v is None or v == '':
-            v = 'None'
+            v = default_value
+            if debug_mode:
+                print(f"Value for key {key} is None or empty, setting to default: {v}")
+
     if validator(v):
-        return v
+        if debug_mode:
+            print(f"Value for key {key} passed validation.")
+    elif corrector:
+        corrected_v = corrector(v)
+        if validator(corrected_v):
+            if debug_mode:
+                print(f"Value for key {key} passed validation after correction.")
+            v = corrected_v
+        else:
+            print(f"Failed to load config key after correction: {json.dumps({key: corrected_v})} is still invalid. Using default: {default_value}")
+            v = default_value
     else:
-        if v is not None:
-            print(f'Failed to load config key: {json.dumps({key:v})} is invalid; will use {json.dumps({key:default_value})} instead.')
-        config_dict[key] = default_value
-        return default_value
+        print(f"Failed to load config key: {json.dumps({key: v})} is invalid. Using default: {default_value}")
+        v = default_value
+
+    config_dict[key] = v
+    return v
+
+def get_model_filenames(folder_path, name_filter=None):
+    return get_files_from_folder(folder_path, ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch'], name_filter)
+
+def update_all_model_names():
+    global model_filenames, lora_filenames
+    model_filenames = get_model_filenames(path_checkpoints)
+    lora_filenames = get_model_filenames(path_loras)
+    return
+
+model_filenames = []
+lora_filenames = []
+update_all_model_names()
+
+def model_validator(value):
+    if isinstance(value, str) and (value == "" or value in model_filenames):
+        return True
+    else:
+        print(f"model_filenames: {model_filenames}") # Debug print
+        print(f"model_validator failed for value: {value}") # Debug print
+        return False
+
+def correct_case_sensitivity(value, valid_values):
+    """Corrects case sensitivity of a value or list of values based on a list of valid values.
+    If a valid value is contained in the input value, it replaces it with the full valid value."""
+    def find_full_match(partial_value):
+        if isinstance(partial_value, str):
+            lower_partial_value = partial_value.lower()
+            for valid_value in valid_values:
+                if lower_partial_value == valid_value.lower():
+                    return valid_value
+                elif lower_partial_value in valid_value.lower():
+                    return valid_value
+        return partial_value
+
+    print(f"Initial value: {value}") # Debug print
+
+    # If value is a string, apply find_full_match directly
+    if isinstance(value, str):
+        corrected_value = find_full_match(value)
+        print(f"Corrected string value: {corrected_value}") # Debug print
+        return corrected_value
+
+    # If value is a list
+    elif isinstance(value, list):
+        print(f"Processing list value: {value}") # Debug print before if cases
+
+        # Check if value is a list of lists (as in default_loras)
+        if all(isinstance(item, list) and len(item) == 2 for item in value):
+            corrected_list = []
+            for sub_value in value:
+                print(f"Processing sub_value: {sub_value}") # Debug print
+                if isinstance(sub_value, list):
+                    corrected_element = find_full_match(sub_value[0])
+                    print(f"Correcting {sub_value[0]} to {corrected_element}") # Debug print
+                    corrected_list.append([corrected_element, sub_value[1]])
+                else:
+                    corrected_list.append(sub_value)
+            print(f"Corrected list of lists: {corrected_list}") # Debug print
+            return corrected_list
+
+        # If value is a regular list
+        else:
+            corrected_list = [find_full_match(val) for val in value]
+            print(f"Corrected regular list: {corrected_list}") # Debug print
+            return corrected_list
+
+    return value
+
+def model_corrector(value):
+    return correct_case_sensitivity(value, model_filenames)
+
 default_base_model_name = get_config_item_or_set_default(
     key='default_model',
     default_value='model.safetensors',
-    validator=lambda x: isinstance(x, str)
+    validator=model_validator,
+    corrector=model_corrector
 )
 previous_default_models = get_config_item_or_set_default(
     key='previous_default_models',
@@ -171,19 +261,40 @@ def get_config_item_or_set_default(key, default_value, validator, disable_empty_
 default_refiner_model_name = get_config_item_or_set_default(
     key='default_refiner',
     default_value='None',
-    validator=lambda x: isinstance(x, str)
+    validator=model_validator,
+    corrector=model_corrector
 )
 default_refiner_switch = get_config_item_or_set_default(
     key='default_refiner_switch',
     default_value=0.8,
     validator=lambda x: isinstance(x, numbers.Number) and 0 <= x <= 1
 )
+
+def loras_validator(x):
+    if not isinstance(x, list):
+        print(f"Validation failed: 'x' is not a list. Value of x: {x}") # Debug print
+        return False
+
+    for y in x:
+        if not (len(y) == 2 and isinstance(y[0], str) and isinstance(y[1], (numbers.Number, float))):
+            print(f"Validation failed: Element structure is incorrect. Element: {y}") # Debug print
+            return False
+        if y[0] != "None" and y[0] not in lora_filenames:
+            print(f"Validation failed: Lora filename not found in lora_filenames. Lora filename: {y[0]}") # Debug print
+            print(f"Available lora_filenames: {lora_filenames}") # Debug print
+            return False
+
+    return True
+
+def loras_corrector(value):
+    return correct_case_sensitivity(value, lora_filenames)
+
 default_loras = get_config_item_or_set_default(
     key='default_loras',
     default_value=[
         [
-            "None",
-            1.0
+            "sd_xl_offset_example-lora_1.0.safetensors",
+            0.1
         ],
         [
             "None",
@@ -224,6 +335,24 @@ def get_config_item_or_set_default(key, default_value, validator, disable_empty_
     default_value='karras',
     validator=lambda x: x in modules.flags.scheduler_list
 )
+
+def sdxl_styles_validator(x):
+    if not isinstance(x, list):
+        print("Validation failed: The variable 'x' is not a list.")
+        print(f"Type of x: {type(x)}")
+        return False
+
+    for y in x:
+        if y not in modules.sdxl_styles.legal_style_names:
+            print("Validation failed: An element in 'x' is not in legal_style_names.")
+            print(f"Failed element: {y}")
+            return False
+
+    return True
+
+def sdxl_styles_corrector(value):
+    return correct_case_sensitivity(value, modules.sdxl_styles.legal_style_names)
+
 default_styles = get_config_item_or_set_default(
     key='default_styles',
     default_value=[
@@ -231,7 +360,8 @@ def get_config_item_or_set_default(key, default_value, validator, disable_empty_
         "Fooocus Enhance",
         "Fooocus Sharp"
     ],
-    validator=lambda x: isinstance(x, list) and all(y in modules.sdxl_styles.legal_style_names for y in x)
+    validator=sdxl_styles_validator,
+    corrector=sdxl_styles_corrector
 )
 default_prompt_negative = get_config_item_or_set_default(
     key='default_prompt_negative',
@@ -388,20 +518,6 @@ def add_ratio(x):
 
 os.makedirs(path_outputs, exist_ok=True)
 
-model_filenames = []
-lora_filenames = []
-
-
-def get_model_filenames(folder_path, name_filter=None):
-    return get_files_from_folder(folder_path, ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch'], name_filter)
-
-
-def update_all_model_names():
-    global model_filenames, lora_filenames
-    model_filenames = get_model_filenames(path_checkpoints)
-    lora_filenames = get_model_filenames(path_loras)
-    return
-
 
 def downloading_inpaint_models(v):
     assert v in modules.flags.inpaint_engine_versions
@@ -514,5 +630,3 @@ def downloading_upscale_model():
     )
     return os.path.join(path_upscale_models, 'fooocus_upscaler_s409985e5.bin')
-
-update_all_model_names()
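
Below is a minimal, self-contained sketch of the validate-then-correct flow the patch introduces; it is illustrative only and not part of the diff. The filename, the simplified signatures, and the lambda corrector are assumptions made for the example; in `modules/config.py` the value comes from `config_dict`, the filename list comes from `get_model_filenames(path_checkpoints)`, and the result is written back to `config_dict[key]`.

```python
# Sketch: a config value that fails validation gets one case-insensitive
# (or substring) correction attempt before falling back to the default.
model_filenames = ["juggernautXL_version6Rundiffusion.safetensors"]  # hypothetical list


def model_validator(value):
    return isinstance(value, str) and (value == "" or value in model_filenames)


def correct_case_sensitivity(value, valid_values):
    # String case only; the real helper also handles lists and [name, weight] pairs.
    if isinstance(value, str):
        for valid_value in valid_values:
            if value.lower() == valid_value.lower() or value.lower() in valid_value.lower():
                return valid_value
    return value


def get_config_item_or_set_default(key, value, default_value, validator, corrector=None):
    # Simplified: no config_dict lookup and no empty-as-default handling.
    if validator(value):
        return value
    if corrector is not None:
        corrected = corrector(value)
        if validator(corrected):
            return corrected
    print(f"Failed to load config key: {key}. Using default: {default_value}")
    return default_value


print(get_config_item_or_set_default(
    key='default_model',
    value='juggernautxl_version6rundiffusion.safetensors',  # wrong casing in config.txt
    default_value='model.safetensors',
    validator=model_validator,
    corrector=lambda v: correct_case_sensitivity(v, model_filenames),
))
# -> juggernautXL_version6Rundiffusion.safetensors
```

With this in place, a `config.txt` entry that differs from the file on disk only by letter case, or that spells out only part of the filename, is repaired instead of being discarded in favour of the default.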