diff --git a/README.md b/README.md
index 49a0a99..6d39ba7 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
 # StyleFlow: Attribute-conditioned Exploration of StyleGAN-Generated Images using Conditional Continuous Normalizing Flows (ACM TOG 2021)
-
+## See you @ SIGGRAPH 2021
 ![Python 3.7](https://img.shields.io/badge/Python-3.7-green.svg?style=plastic)
 ![pytorch 1.1.0](https://img.shields.io/badge/Pytorch-1.1.0-green.svg?style=plastic)
 ![TensorFlow 1.15.0](https://img.shields.io/badge/TensorFlow-1.15.0-green.svg?style=plastic)
@@ -13,6 +13,7 @@ High-quality, diverse, and photorealistic images can now be generated by uncondi
 
 > **StyleFlow: Attribute-conditioned Exploration of StyleGAN-Generated Images using Conditional Continuous Normalizing Flows (ACM TOG 2021)**
> Rameen Abdal, Peihao Zhu, Niloy Mitra, Peter Wonka
+> KAUST, Adobe Research
@@ -22,18 +23,41 @@ High-quality, diverse, and photorealistic images can now be generated by uncondi
 
 [[Promotional Video](https://youtu.be/Lt4Z5oOAeEY)]
 
+## Note: This repo works only on Windows 10
+
+
+
 ## Installation
 
 Clone this repo.
 
 ```bash
-git clone https://github.com/RameenAbdal/StyleFlow.git
+git clone https://github.com/justinjohn0306/StyleFlow.git
 cd StyleFlow/
 ```
 
-This code requires PyTorch, TensorFlow, Torchdiffeq, Python 3+ and Pyqt5. Please install dependencies by
-```bash
-conda env create -f environment.yml
-```
+This code requires PyTorch, TensorFlow, torchdiffeq, Python 3+ and PyQt5.
+
+Please install the dependencies by following these instructions:
+
+1. Download and install [Anaconda (64-bit graphical installer)](https://www.anaconda.com/products/individual), then create the environment:
+   `conda env create -f env_windows.yml`
+
+2. `conda activate styleflow`
+
+3. `conda install pytorch==1.2.0 torchvision==0.4.0 cudatoolkit=10.0 -c pytorch`
+   (other versions are listed [here](https://pytorch.org/get-started/previous-versions/))
+
+4. Download and install [Microsoft Visual Studio Community 2017](https://visualstudio.microsoft.com/vs/older-downloads).
+   Note: MSVS 2017 is specifically required; MSVS 2019 won't work.
+
+   Make sure to add Microsoft Visual Studio to the Windows PATH,
+   e.g. (C:\Program Files (x86)\Microsoft Visual Studio\2017\Community\VC\Auxiliary\Build\vcvars64.bat)
+
+5. Add your compiler path to `compiler_bindir_search_path` inside `custom_ops.py`
+   (in the folder `dnnlib\tflib`),
+   e.g. ('C:/Program Files (x86)/Microsoft Visual Studio/2017/Community/VC/Tools/MSVC/14.16.27023/bin/HostX64/x64',)
 
 StyleGAN2 relies on custom TensorFlow ops that are compiled on the fly using [NVCC](https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html). To correctly set up the StyleGAN2 generator, follow the **Requirements** in [this repo](https://github.com/NVlabs/stylegan2).
@@ -87,18 +111,43 @@ xhost -local:docker
 ```
 
+## Web UI (Beta)
+A web-based UI is now available. It is built on the [Streamlit](https://www.streamlit.io/) framework and is still in development. To get started, install Streamlit and a compatible protobuf from pip:
+```bash
+pip install streamlit
+pip uninstall protobuf python3-protobuf
+pip install --upgrade pip
+pip install --upgrade protobuf
+```
+Then run the Streamlit app located under the `webui/` folder:
+```bash
+cd webui
+streamlit run app.py
+```
+This should automatically open a new browser tab with the UI.
+
+![image](./docs/assets/styleflow-web-final.gif)
+
 
 ## Training New Model
 
+A dataset containing sampled StyleGAN2 latents, lighting SH parameters and other attributes is available. ([Download here](https://drive.google.com/file/d/1opdzeqpYWtE1uexO49JI-3_RWfE9MYlN/view?usp=sharing))
 
-To be added
+Create `./data_numpy/` in the main folder and extract the above data, or create your own dataset.
+
+Train your model:
+```bash
+python train_flow.py
+```
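+The defaults in `train_flow.py` already point at `data_numpy/latents.npy`, `data_numpy/lighting.npy` and `data_numpy/attributes.npy`; every path and hyperparameter can also be set explicitly. A sketch with the defaults spelled out:
+```bash
+python train_flow.py \
+    --latent_path data_numpy/latents.npy \
+    --light_path data_numpy/lighting.npy \
+    --attributes_path data_numpy/attributes.npy \
+    --batch 5 --epochs 3 --lr 1e-3
+```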
+## Projection
+Our new projection method is currently under review. To be updated!
 
 ## License
 
-All rights reserved. Licensed under the [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) (**Attribution-NonCommercial-ShareAlike 4.0 International**) The code is released for academic research use only.
+All rights reserved. Licensed under the [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) (**Attribution-NonCommercial-ShareAlike 4.0 International**). The code is released for academic research use only.
 
 ## Citation
 
-If you use this research/codebase, please cite our papers.
+If you use this research/codebase/dataset, please cite our papers.
 ```
 @article{abdal2020styleflow,
   title={Styleflow: Attribute-conditioned exploration of stylegan-generated images using conditional continuous normalizing flows},
@@ -119,5 +168,9 @@ If you use this research/codebase, please cite our papers.
   pages={4431-4440},
   doi={10.1109/ICCV.2019.00453}}
 ```
+
+## Broader Impact
+*Important*: Deep-learning-based facial imagery, such as DeepFakes and other GAN-generated images, can be gravely misused to spread misinformation and commit other offences. The intent of our work is not to promote such practices; rather, it is meant for areas such as identification (novel views of a subject, occlusion inpainting, etc.), security (facial composites, etc.), image compression (high-quality video conferencing at lower bitrates, etc.) and the development of algorithms for detecting DeepFakes.
+
 ## Acknowledgments
 This implementation builds upon the awesome work done by Karras et al. ([StyleGAN2](https://github.com/NVlabs/stylegan2)), Chen et al. ([torchdiffeq](https://github.com/rtqichen/torchdiffeq)) and Yang et al. ([PointFlow](https://arxiv.org/abs/1906.12320)). This work was supported by Adobe Research and KAUST Office of Sponsored Research (OSR).
diff --git a/dnnlib/tflib/custom_ops.py b/dnnlib/tflib/custom_ops.py
index f87c0d8..7687d7c 100644
--- a/dnnlib/tflib/custom_ops.py
+++ b/dnnlib/tflib/custom_ops.py
@@ -25,18 +25,25 @@ verbose = True # Print status messages to stdout.
 
 compiler_bindir_search_path = [
-    'C:/Program Files (x86)/Microsoft Visual Studio/2017/Community/VC/Tools/MSVC/14.14.26428/bin/Hostx64/x64',
-    'C:/Program Files (x86)/Microsoft Visual Studio/2019/Community/VC/Tools/MSVC/14.23.28105/bin/Hostx64/x64',
-    'C:/Program Files (x86)/Microsoft Visual Studio 14.0/vc/bin',
+    'C:/Program Files (x86)/Microsoft Visual Studio/2017/Community/VC/Tools/MSVC/14.16.27023/bin/HostX64/x64',
 ]
 
 #----------------------------------------------------------------------------
 # Internal helper funcs.
 
 def _find_compiler_bindir():
+
+    # Derive the MSVC compiler path from the compiler_bindir_search_path list
     for compiler_path in compiler_bindir_search_path:
         if os.path.isdir(compiler_path):
             return compiler_path
+
+    # Derive the MSVC compiler path by walking the Visual Studio directory tree
+    subdirectory_paths = [x[0] for x in os.walk('C:\\Program Files (x86)\\Microsoft Visual Studio\\')]
+    for directory_path in subdirectory_paths:
+        if _compiler_path_validator(directory_path):
+            return directory_path
     return None
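+
+# If neither the hard-coded search path above nor the directory walk finds a
+# matching MSVC bin directory, compiling the custom TensorFlow ops will most
+# likely fail. In that case, add your own '...\bin\Hostx64\x64' directory to
+# compiler_bindir_search_path.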
 
 def _get_compute_cap(device):
@@ -61,9 +68,7 @@ def _run_cmd(cmd):
         raise RuntimeError('NVCC returned an error. See below for full command line and output log:\n\n%s\n\n%s' % (cmd, output))
 
 def _prepare_nvcc_cli(opts):
-    # cmd = 'nvcc ' + opts.strip()
-    cmd = '/usr/local/cuda/bin/nvcc --std=c++11 -DNDEBUG ' + opts.strip()
-
+    cmd = 'nvcc ' + opts.strip()
     cmd += ' --disable-warnings'
     cmd += ' --include-path "%s"' % tf.sysconfig.get_include()
     cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'protobuf_archive', 'src')
@@ -81,6 +86,16 @@ def _prepare_nvcc_cli(opts):
     cmd += ' 2>&1'
     return cmd
 
+def _compiler_path_validator(path):
+    # Accept an MSVC host-x64 bin directory under VS 2017 or VS 2019 Community,
+    # or the legacy VS 14.0 compiler path.
+    if path is not None:
+        if path.startswith('C:\\Program Files (x86)\\Microsoft Visual Studio\\2017\\Community\\VC\\Tools\\MSVC\\') and path.endswith('\\bin\\Hostx64\\x64'):
+            return True
+        elif path.startswith('C:\\Program Files (x86)\\Microsoft Visual Studio\\2019\\Community\\VC\\Tools\\MSVC\\') and path.endswith('\\bin\\Hostx64\\x64'):
+            return True
+        elif path == 'C:\\Program Files (x86)\\Microsoft Visual Studio 14.0\\vc\\bin':
+            return True
+    return False
+
 #----------------------------------------------------------------------------
 # Main entry point.
 
diff --git a/docs/_layouts/default.html b/docs/_layouts/default.html
index 705b17c..76517d2 100644
--- a/docs/_layouts/default.html
+++ b/docs/_layouts/default.html
@@ -22,6 +22,8 @@
StyleFlow : Attribute-conditioned Exploration of StyleGAN-generated Images using Conditional Continuous Normalizing Flows (ACM TOG 2021) +
+ See you @ SIGGRAPH 2021
Rameen Abdal1  Peihao Zhu1  diff --git a/docs/assets/styleflow-web-final.gif b/docs/assets/styleflow-web-final.gif new file mode 100644 index 0000000..9e74567 Binary files /dev/null and b/docs/assets/styleflow-web-final.gif differ diff --git a/env_windows.yml b/env_windows.yml new file mode 100644 index 0000000..6f5abc5 --- /dev/null +++ b/env_windows.yml @@ -0,0 +1,108 @@ +name: styleflow +channels: + - pytorch + - anaconda + - defaults +dependencies: + - _pytorch_select=1.1.0=cpu + - _tflow_select=2.1.0=gpu + - absl-py=0.9.0=py36_0 + - astor=0.8.0=py36_0 + - blas=1.0=mkl + - ca-certificates=2020.12.8=haa95532_0 + - certifi=2020.12.5=py36haa95532_0 + - cffi=1.14.0=py36h7a1dbc1_0 + - cloudpickle=1.4.1=py_0 + - cycler=0.10.0=py36h009560c_0 + - cytoolz=0.10.1=py36he774522_0 + - dask-core=2.16.0=py_0 + - decorator=4.4.2=py_0 + - freetype=2.8=vc14h17c9bdf_0 + - gast=0.3.3=py_0 + - grpcio=1.27.2=py36h351948d_0 + - h5py=2.7.1=py36he54a1c3_0 + - hdf5=1.10.1=vc14hb361328_0 + - icc_rt=2019.0.0=h0cc432a_1 + - icu=58.2=vc14hc45fdbb_0 + - imageio=2.8.0=py_0 + - intel-openmp=2020.1=216 + - jpeg=9b=vc14h4d7706e_1 + - keras-applications=1.0.8=py_0 + - keras-preprocessing=1.1.0=py_1 + - kiwisolver=1.2.0=py36h74a9793_0 + - libpng=1.6.37=h2a8f88b_0 + - libprotobuf=3.11.4=h7bd577a_0 + - libtiff=4.1.0=h56a325e_0 + - markdown=3.1.1=py36_0 + - mkl=2020.1=216 + - mkl-service=1.1.2=py36hb217b18_5 + - mkl_fft=1.0.6=py36hdbbee80_0 + - mkl_random=1.0.1=py36h77b88f5_1 + - networkx=2.4=py_0 + - ninja=1.9.0=py36h74a9793_0 + - numpy=1.15.4=py36ha559c80_0 + - numpy-base=1.15.4=py36h8128ebf_0 + - olefile=0.46=py_0 + - openssl=1.1.1i=h2bbff1b_0 + - pip=20.0.2=py36_3 + - protobuf=3.11.4=py36h33f27b4_0 + - pycparser=2.20=py_0 + - pyparsing=2.4.7=py_0 + - pyqt=5.9.2=py36h6538335_2 + - python=3.6.7=h33f27b4_1 + - python-dateutil=2.8.1=py_0 + - pytz=2020.1=py_0 + - pywavelets=1.1.1=py36he774522_0 + - qt=5.9.7=vc14h73c81de_0 + - scikit-image=0.16.2=py36h47e9c7a_0 + - scipy=1.1.0=py36h4f6bf74_1 + - setuptools=46.4.0=py36_0 + - sip=4.19.8=py36h6538335_0 + - six=1.14.0=py36_0 + - sqlite=3.31.1=h2a8f88b_1 + - tensorboard=1.14.0=py36he3c9ec2_0 + - tensorflow=1.14.0=gpu_py36h305fd99_0 + - tensorflow-base=1.14.0=gpu_py36h55fc52a_0 + - tensorflow-estimator=1.14.0=py_0 + - tensorflow-gpu=1.14.0=h0d30ee6_0 + - termcolor=1.1.0=py36_1 + - tk=8.6.7=vc14hb68737d_1 + - toolz=0.10.0=py_0 + - tornado=6.0.4=py36he774522_1 + - vc=14.1=h0510ff6_4 + - vs2015_runtime=14.16.27012=hf0eaf9b_1 + - werkzeug=1.0.1=py_0 + - wheel=0.34.2=py36_0 + - wincertstore=0.2=py36h7fe50ca_0 + - wrapt=1.12.1=py36he774522_1 + - xz=5.2.5=h62dcd97_0 + - zlib=1.2.11=vc14h1cdd9ab_1 + - zstd=1.3.7=h508b16e_0 + - pip: + - chardet==3.0.4 + - dataclasses==0.8 + - deprecated==1.2.10 + - helpdev==0.7.1 + - idna==2.9 + - imagecodecs==2020.2.18 + - importlib-metadata==1.6.0 + - joblib==0.15.1 + - matplotlib==3.2.1 + - opencv-python==4.2.0.34 + - pillow==7.1.2 + - pyqt5==5.14.2 + - pyqt5-sip==12.7.2 + - qdarkgraystyle==1.0.2 + - qdarkstyle==2.8.1 + - qtpy==1.9.0 + - requests==2.23.0 + - scikit-learn==0.23.1 + - threadpoolctl==2.0.0 + - tifffile==2020.5.11 + - torch==1.7.1 + - torchdiffeq==0.0.1 + - tqdm==4.48.0 + - typing-extensions==3.7.4.3 + - urllib3==1.25.9 + - zipp==3.1.0 + diff --git a/flow_weight/modelsmall10k.pt b/flow_weight/modelsmall10k.pt new file mode 100644 index 0000000..67a2bb4 Binary files /dev/null and b/flow_weight/modelsmall10k.pt differ diff --git a/train_flow.py b/train_flow.py new file mode 100644 index 0000000..7c54d04 --- /dev/null +++ b/train_flow.py @@ -0,0 
+1,107 @@
+
+import dnnlib
+from torch import nn, optim
+import torch
+import numpy as np
+from torch.utils import data
+from module.flow import cnf
+from math import log, pi
+import os
+from tqdm import tqdm
+
+import random
+import torchvision.transforms as transforms
+from torch.utils.data import Dataset
+import torchvision.datasets as dset
+import argparse
+
+
+def standard_normal_logprob(z):
+    # Per-dimension log-density of a standard normal evaluated at z.
+    dim = z.size(-1)
+    log_z = -0.5 * dim * log(2 * pi)
+    return log_z - z.pow(2) / 2
+
+
+class MyDataset(Dataset):
+    def __init__(self, latents, attributes, transform=None):
+        self.latents = latents
+        self.attributes = attributes
+        self.transform = transform
+
+    def __getitem__(self, index):
+        x = self.latents[index]
+        y = self.attributes[index]
+
+        if self.transform:
+            x = self.transform(x)
+            y = self.transform(y)
+
+        return x, y
+
+    def __len__(self):
+        return len(self.latents)
+
+
+if __name__ == '__main__':
+
+    parser = argparse.ArgumentParser(description="StyleFlow trainer")
+
+    parser.add_argument("--latent_path", default='data_numpy/latents.npy', type=str, help="path to the latents")
+    parser.add_argument("--light_path", default='data_numpy/lighting.npy', type=str, help="path to the lighting parameters")
+    parser.add_argument("--attributes_path", default='data_numpy/attributes.npy', type=str, help="path to the attribute parameters")
+    parser.add_argument("--batch", type=int, default=5, help="batch size")
+    parser.add_argument("--epochs", type=int, default=3, help="number of epochs")
+    parser.add_argument("--flow_modules", type=str, default='512-512-512-512-512')
+    parser.add_argument("--cond_size", type=int, default=17)
+    parser.add_argument("--lr", type=float, default=1e-3)
+
+    args = parser.parse_args()
+    torch.manual_seed(0)
+
+    prior = cnf(512, args.flow_modules, args.cond_size, 1)
+
+    sg_latents = np.load(args.latent_path)
+    lighting = np.load(args.light_path)
+    attributes = np.load(args.attributes_path)
+    sg_attributes = np.concatenate([lighting, attributes], axis=1)
+
+    my_dataset = MyDataset(latents=torch.Tensor(sg_latents).cuda(), attributes=torch.tensor(sg_attributes).float().cuda())
+    # drop_last avoids a shape mismatch with the zero-padding tensor below when
+    # the dataset size is not divisible by the batch size.
+    train_loader = data.DataLoader(my_dataset, shuffle=False, batch_size=args.batch, drop_last=True)
+
+    optimizer = optim.Adam(prior.parameters(), lr=args.lr)
+
+    os.makedirs('trained_model', exist_ok=True)
+
+    with tqdm(range(args.epochs)) as pbar:
+        for epoch in pbar:
+            for i, x in enumerate(train_loader):
+
+                approx21, delta_log_p2 = prior(x[0].squeeze(1), x[1], torch.zeros(args.batch, x[0].shape[2], 1).to(x[0]))
+
+                # log p(w | a) = log p(z) - delta_log_p (CNF change of variables)
+                approx2 = standard_normal_logprob(approx21).view(args.batch, -1).sum(1, keepdim=True)
+                delta_log_p2 = delta_log_p2.view(args.batch, x[0].shape[2], 1).sum(1)
+                log_p2 = (approx2 - delta_log_p2)
+
+                loss = -log_p2.mean()
+
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+
+                pbar.set_description(f'NLL: {loss:.5f}')
+
+                if i % 1000 == 0:
+                    torch.save(
+                        prior.state_dict(), f'trained_model/modellarge10k_{str(i).zfill(6)}_{str(epoch).zfill(2)}.pt'
+                    )
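+
+# A minimal sketch of reloading a saved checkpoint for inference (this mirrors
+# what webui/app.py does; the exact filename depends on the step and epoch at
+# which the checkpoint was written):
+#   prior = cnf(512, args.flow_modules, args.cond_size, 1)
+#   prior.load_state_dict(torch.load('trained_model/modellarge10k_000000_00.pt'))
+#   prior.eval()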
diff --git a/webui/app.py b/webui/app.py
new file mode 100644
index 0000000..390a457
--- /dev/null
+++ b/webui/app.py
@@ -0,0 +1,312 @@
+# Author shariqfarooq123
+
+import streamlit as st
+
+st.set_page_config(
+    layout="wide",  # Can be "centered" or "wide"
+    initial_sidebar_state="expanded",  # Can be "auto", "expanded", "collapsed"
+    page_title="StyleFlow web demo",  # String or None. Strings get appended with "• Streamlit".
+    page_icon=None,  # String, anything supported by st.image, or None.
+)
+import sys
+sys.path.insert(0, "../")
+
+from options.test_options import TestOptions
+
+import numpy as np
+
+from utils import Build_model
+import torch
+import torch.nn
+from module.flow import cnf
+import os
+import tensorflow as tf
+import pickle
+import copy
+
+
+""" # Welcome to StyleFlow WebUI demo (Beta)
+Go wild!
+"""
+
+
+# Currently TF runs on the GPU and the flow model (PyTorch) runs on the CPU. Should be fairly fast anyway
+# TODO: Need to get around CUDA memory overflow bugs to enable flow inference on GPU
+
+
+DATA_ROOT = "../data"
+HASH_FUNCS = {tf.Session: id,
+              torch.nn.Module: id,
+              Build_model: lambda _: None,
+              torch.Tensor: lambda x: x.cpu().numpy()}
+
+# Select images
+all_idx = np.array([2, 5, 25, 28, 16, 32, 33, 34, 55, 75, 79, 162, 177, 196, 160, 212, 246, 285, 300, 329, 362,
+                    369, 462, 460, 478, 551, 583, 643, 879, 852, 914, 999, 976, 627, 844, 237, 52, 301,
+                    599], dtype='int')
+
+EPS = 1e-3  # arbitrary positive value
+
+
+class State:  # Simple dirty hack for maintaining state
+    prev_attr = None
+    prev_idx = None
+    first = True
+    # ... and other state variables
+
+if not hasattr(st, 'data'):  # Run only once. Save data globally
+
+    st.state = State()
+    with st.spinner("Setting up... This might take a few minutes"):
+        raw_w = pickle.load(open(os.path.join(DATA_ROOT, "sg2latents.pickle"), "rb"))
+        # raw_TSNE = np.load(os.path.join(DATA_ROOT, 'TSNE.npy'))  # We are picking images here by index instead
+        raw_attr = np.load(os.path.join(DATA_ROOT, 'attributes.npy'))
+        raw_lights = np.load(os.path.join(DATA_ROOT, 'light.npy'))
+
+        all_w = np.array(raw_w['Latent'])[all_idx]
+        all_attr = raw_attr[all_idx]
+        all_lights = raw_lights[all_idx]
+
+        light0 = torch.from_numpy(raw_lights[8]).float()
+        light1 = torch.from_numpy(raw_lights[33]).float()
+        light2 = torch.from_numpy(raw_lights[641]).float()
+        light3 = torch.from_numpy(raw_lights[547]).float()
+        light4 = torch.from_numpy(raw_lights[28]).float()
+        light5 = torch.from_numpy(raw_lights[34]).float()
+
+        pre_lighting = [light0, light1, light2, light3, light4, light5]
+
+        st.data = dict(raw_w=raw_w, all_w=all_w, all_attr=all_attr, all_lights=all_lights,
+                       pre_lighting=pre_lighting)
+
+
+def make_slider(name, min_value=0.0, max_value=1.0, step=0.1, **kwargs):
+    return st.sidebar.slider(name, min_value, max_value, step=step, **kwargs)
+
+@st.cache(allow_output_mutation=True, hash_funcs={dict: id}, show_spinner=False)
+def get_idx2init(raw_w):
+    print(type(raw_w))
+    idx2init = {i: np.array(raw_w['Latent'])[i] for i in all_idx}
+    return idx2init
+
+@st.cache(hash_funcs=HASH_FUNCS)
+def init_model():
+    # Open a new TensorFlow session.
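+    # (allow_soft_placement lets TensorFlow fall back to the CPU when an op
+    # has no GPU implementation.)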
+ config = tf.ConfigProto(allow_soft_placement=True) + session = tf.Session(config=config) + + opt = TestOptions().parse() + with session.as_default(): + model = Build_model(opt) + w_avg = model.Gs.get_var('dlatent_avg') + + prior = cnf(512, '512-512-512-512-512', 17, 1) + prior.load_state_dict(torch.load('../flow_weight/modellarge10k.pt')) + prior.eval() + + return session, model, w_avg, prior.cpu() + +@st.cache(allow_output_mutation=True, show_spinner=False, hash_funcs=HASH_FUNCS) +@torch.no_grad() +def flow_w_to_z(flow_model, w, attributes, lighting): + w_cuda = torch.Tensor(w) + att_cuda = torch.from_numpy(np.asarray(attributes)).float().unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + light_cuda = torch.Tensor(lighting) + + features = torch.cat([light_cuda, att_cuda], dim=1).clone().detach() + zero_padding = torch.zeros(1, 18, 1) + z = flow_model(w_cuda, features, zero_padding)[0].clone().detach() + + return z + +@st.cache(allow_output_mutation=True, show_spinner=False, hash_funcs=HASH_FUNCS) +@torch.no_grad() +def flow_z_to_w(flow_model, z, attributes, lighting): + att_cuda = torch.Tensor(np.asarray(attributes)).float().unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + light_cuda = torch.Tensor(lighting) + + features = torch.cat([light_cuda, att_cuda], dim=1).clone().detach() + zero_padding = torch.zeros(1, 18, 1) + w = flow_model(z, features, zero_padding, True)[0].clone().detach().numpy() + + return w + +@st.cache(show_spinner=False, hash_funcs=HASH_FUNCS) +@torch.no_grad() +def generate_image(session, model, w): + with session.as_default(): + img = model.generate_im_from_w_space(w)[0].copy() + return img + +def preserve_w_id(w_new, w_orig, attr_index): + # Ssssh! secret sauce to strip vectors + w_orig = torch.Tensor(w_orig) + if attr_index == 0: + w_new[0][8:] = w_orig[0][8:] + + elif attr_index == 1: + w_new[0][:2] = w_orig[0][:2] + w_new[0][4:] = w_orig[0][4:] + + elif attr_index == 2: + + w_new[0][4:] = w_orig[0][4:] + + elif attr_index == 3: + w_new[0][4:] = w_orig[0][4:] + + elif attr_index == 4: + w_new[0][6:] = w_orig[0][6:] + + elif attr_index == 5: + w_new[0][:5] = w_orig[0][:5] + w_new[0][10:] = w_orig[0][10:] + + elif attr_index == 6: + w_new[0][0:4] = w_orig[0][0:4] + w_new[0][8:] = w_orig[0][8:] + + elif attr_index == 7: + w_new[0][:4] = w_orig[0][:4] + w_new[0][6:] = w_orig[0][6:] + return w_new + + +def is_new_idx_set(idx): + if st.state.first: + st.state.first = False + st.state.prev_idx = idx + return True + + if idx != st.state.prev_idx: + st.state.prev_idx = idx + return True + return False + +def reset_state(idx): + st.state = State() + st.state.first = False + st.state.prev_idx = idx + +def np_copy(*args): # shortcut to clone multiple arrays + return [np.copy(arg) for arg in args] + +def get_changed_light(lights, light_names): + for i, name in enumerate(light_names): + change = abs(lights[name] - st.state.prev_lights[i]) + if change > EPS: + return i + return None + + + +def main(): + attribute_names = ['Gender', 'Glasses', 'Yaw', 'Pitch', 'Baldness', 'Beard', 'Age', 'Expression'] + attr_degree_list = [1.5, 2.5, 1., 1., 2, 1.7, 0.93, 1.] + + light_names = ['Left->Right', 'Right->Left', 'Down->Up', 'Up->Down', 'No light', 'Front light'] + + att_min = {'Gender': 0, 'Glasses': 0, 'Yaw': -20, 'Pitch': -20, 'Baldness': 0, 'Beard': 0.0, 'Age': 0, + 'Expression': 0} + att_max = {'Gender': 1, 'Glasses': 1, 'Yaw': 20, 'Pitch': 20, 'Baldness': 1, 'Beard': 1, 'Age': 65, 'Expression': 1} + + + with st.spinner("Setting up... This might take a few minutes... 
Please wait!"): + all_w, all_attr, all_lights = np_copy(st.data["all_w"], st.data["all_attr"], st.data["all_lights"]) + pre_lighting = list(st.data["pre_lighting"]) + idx2w_init = get_idx2init(st.data["raw_w"]) + session, model, w_avg, flow_model = init_model() + + idx_selected = st.selectbox("Choose an image:", list(range(len(idx2w_init))), + format_func= lambda opt : all_idx[opt]) + + w_selected = all_w[idx_selected] + attr_selected = all_attr[idx_selected].ravel() + lights_selected = all_lights[idx_selected] + z_selected = flow_w_to_z(flow_model, w_selected, attr_selected, lights_selected) + + if is_new_idx_set(idx_selected): + reset_state(idx_selected) + st.state.prev_attr = attr_selected.copy() + st.state.prev_lights = lights_selected.ravel().copy() + st.state.z_current = copy.deepcopy(z_selected) + st.state.w_current = torch.Tensor(w_selected) + st.state.w_prev = torch.Tensor(w_selected) + st.state.light_current = torch.Tensor(lights_selected).float() + + st.sidebar.markdown("# Attributes") + attributes = {} + for i, att in enumerate(attribute_names): + attributes[att] = make_slider(att, float(att_min[att]), float(att_max[att]), + value=float(attr_selected.ravel()[i]), # value on first render + key=hash(idx_selected*1e5 + i) # re-render if index selected is changed! + ) + + st.sidebar.markdown("# Lighting") + lights = {} + for i, lt in enumerate(light_names): + lights[lt] = make_slider(lt, + value=float(lights_selected.ravel()[i]), # value on first render + key=hash(idx_selected*1e6 + i) # re-render if index selected is changed! + ) + + img_source = generate_image(session, model, w_selected) + + att_new = list(attributes.values()) + + for i, att in enumerate(attribute_names): # Not the greatest code, but works! + attr_change = attributes[att] - st.state.prev_attr[i] + + if abs(attr_change) > EPS: + print(f"Changed attr {att} : {attr_change}") + attr_final = attr_degree_list[i] * attr_change + st.state.prev_attr[i] + att_new[i] = attr_final + print("\n") + + if hasattr(st.state, 'prev_changed') and st.state.prev_changed != att: + st.state.z_current = flow_w_to_z(flow_model, st.state.w_current, st.state.prev_attr_factored, lights_selected) + st.state.prev_attr[i] = attributes[att] + st.state.prev_changed = att + st.state.prev_attr_factored = att_new + st.state.w_current = flow_z_to_w(flow_model, st.state.z_current, att_new, lights_selected) + break # Streamlit re-runs on each interaction. Probably works but need to test for any bugs here + + pre_lighting_distance = [pre_lighting[i] - st.state.light_current for i in range(len(light_names))] + lights_magnitude = np.zeros(len(light_names)) + changed_light_index = get_changed_light(lights, light_names) + + if changed_light_index is not None: + lights_magnitude[changed_light_index] = lights[light_names[changed_light_index]] + + lighting_final = torch.Tensor(st.state.light_current) + for i in range(len(light_names)): + lighting_final += lights_magnitude[i] * pre_lighting_distance[i] + + w_current = flow_z_to_w(flow_model, st.state.z_current, att_new, lighting_final) + + w_current[0][0:7] = st.state.w_current[0][0:7] # some stripping + w_current[0][12:18] = st.state.w_current[0][12:18] + + st.state.w_current = w_current + lights_new = lighting_final + + st.state.prev_lights[changed_light_index] = lights[light_names[changed_light_index]] + else: + lights_new = lights_selected + + col1, col2 = st.beta_columns(2) # Columns feature of streamlit is still in beta. 
This line may need to be changed in future versions
+    with col1:
+        st.image(img_source, caption="Generated", use_column_width=True)
+
+    with col2:
+        st.state.w_current = preserve_w_id(st.state.w_current, st.state.w_prev, i)
+        img_target = generate_image(session, model, st.state.w_current)
+        st.image(img_target, caption="Target", use_column_width=True)
+
+    st.state.z_current = flow_w_to_z(flow_model, st.state.w_current, att_new, lights_new)
+    st.state.w_prev = torch.Tensor(st.state.w_current).clone().detach()
+
+
+if __name__ == '__main__':
+    main()