
A network that estimates the relative distribution within a domain in unsupervised domain transfer tasks.


MichelleAppel/importance_net


Importance Net

ImportanceNet is a network that estimates the relative distribution of a domain for unsupervised domain transfer tasks.

Image-to-image translation methods faithfully transform a source image to the style of a target domain, enabling a wide range of applications in domain adaptation and the arts. Existing approaches focus on image quality, while little attention is given to the distribution of generated images, which by default follows that of the source domain. We propose a method to re-sample a set of generated images to match the distribution of the target domain. At its core is a neural-network module that estimates the relative frequency of image constellations in the source and target domains by matching modes of features.
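The re-sampling step described above can be illustrated with a minimal, stdlib-only sketch. It assumes each generated image already has a non-negative importance weight (how much more frequent its content is in the target domain than in the source domain) and draws a new set with probability proportional to those weights. The weight function `ratio_from_probability` is a hypothetical stand-in based on the classifier density-ratio trick (w = d / (1 - d) for a domain classifier output d); it is not how ImportanceNet computes its estimates, which the text above describes as matching modes of features.

```python
import random
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def ratio_from_probability(d: float, eps: float = 1e-6) -> float:
    """Hypothetical weight: the density ratio p_target / p_source implied by a
    domain classifier that outputs d = P(target | x), assuming balanced classes."""
    d = min(max(d, eps), 1.0 - eps)  # clip so the ratio stays finite
    return d / (1.0 - d)


def resample(samples: Sequence[T], weights: Sequence[float], k: int,
             seed: int = 0) -> List[T]:
    """Draw k samples with replacement, each with probability proportional to its weight."""
    if len(samples) != len(weights):
        raise ValueError("samples and weights must have the same length")
    if any(w < 0 for w in weights) or sum(weights) == 0:
        raise ValueError("weights must be non-negative and not all zero")
    rng = random.Random(seed)  # seeded so the draw is reproducible
    return rng.choices(samples, weights=weights, k=k)


if __name__ == "__main__":
    images = ["img_a", "img_b", "img_c"]  # placeholders for generated images
    probs = [0.2, 0.5, 0.8]               # hypothetical classifier outputs
    weights = [ratio_from_probability(p) for p in probs]
    picked = resample(images, weights, k=1000)
    print({name: picked.count(name) for name in images})
```

With these hypothetical outputs the weights are 0.25, 1 and 4, so `img_c` should appear roughly 16 times as often as `img_a` in the re-sampled set, shifting its composition toward the (assumed) target distribution.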

To run the example, move to the example folder:

cd cycleGAN_example

To train the CycleGAN while keeping ImportanceNet frozen, run:

python3 train.py --train_GAN

Next, to train ImportanceNet using the pretrained CycleGAN, run:

python3 train.py --continue_train --train_W

The importance_net folder in the CycleGAN example contains the network that performs the distribution estimation. To use it within your own training loop, add the pieces of code delimited by ############# in train.py. Specifically, the following are the only pieces of code you need to add to your training code to integrate ImportanceNet.

The imports:

from importance_net.models import create_model, importance_model, network
from importance_net.options.train_options import TrainOptions as ImportanceTrainOptions

importance_opt = ImportanceTrainOptions().parse()  # instantiate the options class, then parse

Initialize the network:

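# `model` and `opt` are the CycleGAN model and options already created in your train.py
objective_function = network.DiscriminatorLoss(model=model, gan_mode=opt.gan_mode)
importance_model = importance_model.ImportanceModel(importance_opt, objective_function.criterion)  # rebinds the imported module name to the instance
importance_model.setup(importance_opt) # setup using options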

And run within your training loop:

importance_model.set_input(data)         # unpack data from dataset and apply preprocessing
importance_model.optimize_parameters()   # calculate loss functions, get gradients, update network weights
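As a minimal sketch of where these two calls sit in a training loop, runnable without the repository: `StubModel` is a hypothetical stand-in that only records calls, where real code would use the CycleGAN `model` and the `ImportanceModel` created above. Only `set_input` and `optimize_parameters` come from the snippet above; the epoch/batch structure and the mapping of the `train_gan`/`train_w` booleans onto the `--train_GAN`/`--train_W` flags are assumptions.

```python
from typing import Dict, List, Sequence


class StubModel:
    """Hypothetical stand-in for a model exposing set_input/optimize_parameters."""

    def __init__(self, name: str, log: List[str]) -> None:
        self.name = name
        self.log = log
        self.data = None

    def set_input(self, data: Dict[str, object]) -> None:
        self.data = data  # a real model would unpack and preprocess the batch here
        self.log.append(f"{self.name}.set_input")

    def optimize_parameters(self) -> None:
        # a real model would compute losses and update its weights here
        self.log.append(f"{self.name}.optimize_parameters")


def train(dataset: Sequence[Dict[str, object]], n_epochs: int,
          gan_model: StubModel, importance_model: StubModel,
          train_gan: bool, train_w: bool) -> None:
    """Assumed two-stage schedule: train_gan ~ --train_GAN, train_w ~ --train_W."""
    for _ in range(n_epochs):
        for data in dataset:
            if train_gan:  # stage 1: update the CycleGAN only
                gan_model.set_input(data)
                gan_model.optimize_parameters()
            if train_w:    # stage 2: update ImportanceNet on the pretrained CycleGAN
                importance_model.set_input(data)
                importance_model.optimize_parameters()


if __name__ == "__main__":
    log: List[str] = []
    batches = [{"A": 0, "B": 1}, {"A": 2, "B": 3}]
    train(batches, 1, StubModel("gan", log), StubModel("importance", log),
          train_gan=False, train_w=True)
    print(log)
```

Keeping the two updates behind separate switches mirrors the two commands in the example: the first stage never touches ImportanceNet, and the second reuses the pretrained CycleGAN while only ImportanceNet is optimized.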
