Skip to content

Commit

Permalink
feat: parse env var strings to expected config value types (#3107)
Browse files Browse the repository at this point in the history
* fix: add try_parse_bool for env var strings to enable config overrides of boolean values

* fix: fallback to given value if not parseable

* feat: extend eval to all valid types

* fix: remove return type

* fix: prevent strange type conversions by providing expected type

* feat: add tests
  • Loading branch information
mashb1t authored Jun 6, 2024
1 parent 04d7648 commit 5abae22
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 39 deletions.
118 changes: 79 additions & 39 deletions modules/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
import json
import math
import numbers

import args_manager
import tempfile
import modules.flags
import modules.sdxl_styles

from modules.model_loader import load_file_from_url
from modules.extra_utils import makedirs_with_log, get_files_from_folder
from modules.extra_utils import makedirs_with_log, get_files_from_folder, try_eval_env_var
from modules.flags import OutputFormat, Performance, MetadataScheme


Expand Down Expand Up @@ -200,14 +201,15 @@ def get_dir_or_set_default(key, default_value, as_array=False, make_directory=Fa
path_outputs = get_path_output()


def get_config_item_or_set_default(key, default_value, validator, disable_empty_as_none=False):
def get_config_item_or_set_default(key, default_value, validator, disable_empty_as_none=False, expected_type=None):
global config_dict, visited_keys

if key not in visited_keys:
visited_keys.append(key)

v = os.getenv(key)
if v is not None:
v = try_eval_env_var(v, expected_type)
print(f"Environment: {key} = {v}")
config_dict[key] = v

Expand Down Expand Up @@ -252,41 +254,49 @@ def init_temp_path(path: str | None, default_path: str) -> str:
key='temp_path',
default_value=default_temp_path,
validator=lambda x: isinstance(x, str),
expected_type=str
), default_temp_path)
temp_path_cleanup_on_launch = get_config_item_or_set_default(
key='temp_path_cleanup_on_launch',
default_value=True,
validator=lambda x: isinstance(x, bool)
validator=lambda x: isinstance(x, bool),
expected_type=bool
)
default_base_model_name = default_model = get_config_item_or_set_default(
key='default_model',
default_value='model.safetensors',
validator=lambda x: isinstance(x, str)
validator=lambda x: isinstance(x, str),
expected_type=str
)
previous_default_models = get_config_item_or_set_default(
key='previous_default_models',
default_value=[],
validator=lambda x: isinstance(x, list) and all(isinstance(k, str) for k in x)
validator=lambda x: isinstance(x, list) and all(isinstance(k, str) for k in x),
expected_type=list
)
default_refiner_model_name = default_refiner = get_config_item_or_set_default(
key='default_refiner',
default_value='None',
validator=lambda x: isinstance(x, str)
validator=lambda x: isinstance(x, str),
expected_type=str
)
default_refiner_switch = get_config_item_or_set_default(
key='default_refiner_switch',
default_value=0.8,
validator=lambda x: isinstance(x, numbers.Number) and 0 <= x <= 1
validator=lambda x: isinstance(x, numbers.Number) and 0 <= x <= 1,
expected_type=numbers.Number
)
default_loras_min_weight = get_config_item_or_set_default(
key='default_loras_min_weight',
default_value=-2,
validator=lambda x: isinstance(x, numbers.Number) and -10 <= x <= 10
validator=lambda x: isinstance(x, numbers.Number) and -10 <= x <= 10,
expected_type=numbers.Number
)
default_loras_max_weight = get_config_item_or_set_default(
key='default_loras_max_weight',
default_value=2,
validator=lambda x: isinstance(x, numbers.Number) and -10 <= x <= 10
validator=lambda x: isinstance(x, numbers.Number) and -10 <= x <= 10,
expected_type=numbers.Number
)
default_loras = get_config_item_or_set_default(
key='default_loras',
Expand Down Expand Up @@ -320,38 +330,45 @@ def init_temp_path(path: str | None, default_path: str) -> str:
validator=lambda x: isinstance(x, list) and all(
len(y) == 3 and isinstance(y[0], bool) and isinstance(y[1], str) and isinstance(y[2], numbers.Number)
or len(y) == 2 and isinstance(y[0], str) and isinstance(y[1], numbers.Number)
for y in x)
for y in x),
expected_type=list
)
default_loras = [(y[0], y[1], y[2]) if len(y) == 3 else (True, y[0], y[1]) for y in default_loras]
default_max_lora_number = get_config_item_or_set_default(
key='default_max_lora_number',
default_value=len(default_loras) if isinstance(default_loras, list) and len(default_loras) > 0 else 5,
validator=lambda x: isinstance(x, int) and x >= 1
validator=lambda x: isinstance(x, int) and x >= 1,
expected_type=int
)
default_cfg_scale = get_config_item_or_set_default(
key='default_cfg_scale',
default_value=7.0,
validator=lambda x: isinstance(x, numbers.Number)
validator=lambda x: isinstance(x, numbers.Number),
expected_type=numbers.Number
)
default_sample_sharpness = get_config_item_or_set_default(
key='default_sample_sharpness',
default_value=2.0,
validator=lambda x: isinstance(x, numbers.Number)
validator=lambda x: isinstance(x, numbers.Number),
expected_type=numbers.Number
)
default_sampler = get_config_item_or_set_default(
key='default_sampler',
default_value='dpmpp_2m_sde_gpu',
validator=lambda x: x in modules.flags.sampler_list
validator=lambda x: x in modules.flags.sampler_list,
expected_type=str
)
default_scheduler = get_config_item_or_set_default(
key='default_scheduler',
default_value='karras',
validator=lambda x: x in modules.flags.scheduler_list
validator=lambda x: x in modules.flags.scheduler_list,
expected_type=str
)
default_vae = get_config_item_or_set_default(
key='default_vae',
default_value=modules.flags.default_vae,
validator=lambda x: isinstance(x, str)
validator=lambda x: isinstance(x, str),
expected_type=str
)
default_styles = get_config_item_or_set_default(
key='default_styles',
Expand All @@ -360,121 +377,144 @@ def init_temp_path(path: str | None, default_path: str) -> str:
"Fooocus Enhance",
"Fooocus Sharp"
],
validator=lambda x: isinstance(x, list) and all(y in modules.sdxl_styles.legal_style_names for y in x)
validator=lambda x: isinstance(x, list) and all(y in modules.sdxl_styles.legal_style_names for y in x),
expected_type=list
)
default_prompt_negative = get_config_item_or_set_default(
key='default_prompt_negative',
default_value='',
validator=lambda x: isinstance(x, str),
disable_empty_as_none=True
disable_empty_as_none=True,
expected_type=str
)
default_prompt = get_config_item_or_set_default(
key='default_prompt',
default_value='',
validator=lambda x: isinstance(x, str),
disable_empty_as_none=True
disable_empty_as_none=True,
expected_type=str
)
default_performance = get_config_item_or_set_default(
key='default_performance',
default_value=Performance.SPEED.value,
validator=lambda x: x in Performance.list()
validator=lambda x: x in Performance.list(),
expected_type=str
)
default_advanced_checkbox = get_config_item_or_set_default(
key='default_advanced_checkbox',
default_value=False,
validator=lambda x: isinstance(x, bool)
validator=lambda x: isinstance(x, bool),
expected_type=bool
)
default_max_image_number = get_config_item_or_set_default(
key='default_max_image_number',
default_value=32,
validator=lambda x: isinstance(x, int) and x >= 1
validator=lambda x: isinstance(x, int) and x >= 1,
expected_type=int
)
default_output_format = get_config_item_or_set_default(
key='default_output_format',
default_value='png',
validator=lambda x: x in OutputFormat.list()
validator=lambda x: x in OutputFormat.list(),
expected_type=str
)
default_image_number = get_config_item_or_set_default(
key='default_image_number',
default_value=2,
validator=lambda x: isinstance(x, int) and 1 <= x <= default_max_image_number
validator=lambda x: isinstance(x, int) and 1 <= x <= default_max_image_number,
expected_type=int
)
checkpoint_downloads = get_config_item_or_set_default(
key='checkpoint_downloads',
default_value={},
validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items())
validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items()),
expected_type=dict
)
lora_downloads = get_config_item_or_set_default(
key='lora_downloads',
default_value={},
validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items())
validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items()),
expected_type=dict
)
embeddings_downloads = get_config_item_or_set_default(
key='embeddings_downloads',
default_value={},
validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items())
validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items()),
expected_type=dict
)
available_aspect_ratios = get_config_item_or_set_default(
key='available_aspect_ratios',
default_value=modules.flags.sdxl_aspect_ratios,
validator=lambda x: isinstance(x, list) and all('*' in v for v in x) and len(x) > 1
validator=lambda x: isinstance(x, list) and all('*' in v for v in x) and len(x) > 1,
expected_type=list
)
default_aspect_ratio = get_config_item_or_set_default(
key='default_aspect_ratio',
default_value='1152*896' if '1152*896' in available_aspect_ratios else available_aspect_ratios[0],
validator=lambda x: x in available_aspect_ratios
validator=lambda x: x in available_aspect_ratios,
expected_type=str
)
default_inpaint_engine_version = get_config_item_or_set_default(
key='default_inpaint_engine_version',
default_value='v2.6',
validator=lambda x: x in modules.flags.inpaint_engine_versions
validator=lambda x: x in modules.flags.inpaint_engine_versions,
expected_type=str
)
default_cfg_tsnr = get_config_item_or_set_default(
key='default_cfg_tsnr',
default_value=7.0,
validator=lambda x: isinstance(x, numbers.Number)
validator=lambda x: isinstance(x, numbers.Number),
expected_type=numbers.Number
)
default_clip_skip = get_config_item_or_set_default(
key='default_clip_skip',
default_value=2,
validator=lambda x: isinstance(x, int) and 1 <= x <= modules.flags.clip_skip_max
validator=lambda x: isinstance(x, int) and 1 <= x <= modules.flags.clip_skip_max,
expected_type=int
)
default_overwrite_step = get_config_item_or_set_default(
key='default_overwrite_step',
default_value=-1,
validator=lambda x: isinstance(x, int)
validator=lambda x: isinstance(x, int),
expected_type=int
)
default_overwrite_switch = get_config_item_or_set_default(
key='default_overwrite_switch',
default_value=-1,
validator=lambda x: isinstance(x, int)
validator=lambda x: isinstance(x, int),
expected_type=int
)
example_inpaint_prompts = get_config_item_or_set_default(
key='example_inpaint_prompts',
default_value=[
'highly detailed face', 'detailed girl face', 'detailed man face', 'detailed hand', 'beautiful eyes'
],
validator=lambda x: isinstance(x, list) and all(isinstance(v, str) for v in x)
validator=lambda x: isinstance(x, list) and all(isinstance(v, str) for v in x),
expected_type=list
)
default_black_out_nsfw = get_config_item_or_set_default(
key='default_black_out_nsfw',
default_value=False,
validator=lambda x: isinstance(x, bool)
validator=lambda x: isinstance(x, bool),
expected_type=bool
)
default_save_metadata_to_images = get_config_item_or_set_default(
key='default_save_metadata_to_images',
default_value=False,
validator=lambda x: isinstance(x, bool)
validator=lambda x: isinstance(x, bool),
expected_type=bool
)
default_metadata_scheme = get_config_item_or_set_default(
key='default_metadata_scheme',
default_value=MetadataScheme.FOOOCUS.value,
validator=lambda x: x in [y[1] for y in modules.flags.metadata_scheme if y[1] == x]
validator=lambda x: x in [y[1] for y in modules.flags.metadata_scheme if y[1] == x],
expected_type=str
)
metadata_created_by = get_config_item_or_set_default(
key='metadata_created_by',
default_value='',
validator=lambda x: isinstance(x, str)
validator=lambda x: isinstance(x, str),
expected_type=str
)

example_inpaint_prompts = [[x] for x in example_inpaint_prompts]
Expand Down
15 changes: 15 additions & 0 deletions modules/extra_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os
from ast import literal_eval


def makedirs_with_log(path):
try:
Expand All @@ -24,3 +26,16 @@ def get_files_from_folder(folder_path, extensions=None, name_filter=None):
filenames.append(path)

return filenames


def try_eval_env_var(value: str, expected_type=None):
try:
value_eval = value
if expected_type is bool:
value_eval = value.title()
value_eval = literal_eval(value_eval)
if expected_type is not None and not isinstance(value_eval, expected_type):
return value
return value_eval
except:
return value
Loading

0 comments on commit 5abae22

Please sign in to comment.