dslisleedh/GinConfigExample

Simple MNIST Classification Project Using Gin Config and TF2


This is a simple MNIST classification project using Gin Config and TensorFlow 2. To create the conda environment and start training:

conda env create -f environment.yaml
conda activate gin
python train.py

You can change the model by overriding the model argument:

python train.py model=mlpmixer

Hyperparameters are set in the following config files:

  • ./conf/models/[model_name].gin (Model selection and hyperparameters)
  • ./conf/optimizer/config.gin (Optimizer selection and hyperparameters)
  • ./conf/others/config.gin (Other training-related hyperparameters, e.g. loss_fn, batch_size, epochs, ...)
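
train.py itself is not shown here, so the following is only a minimal sketch of how a `model=mlpmixer` style override could be turned into the list of config files above. The helper names `parse_overrides` and `config_files` and the default model are hypothetical, and the `[model_name].gin` naming follows the pattern listed above rather than anything confirmed in the code.

```python
from pathlib import Path

# Assumption: the default model is hypothetical; train.py may pick another one.
DEFAULTS = {"model": "simple_mlp"}


def parse_overrides(argv):
    """Turn ['model=mlpmixer'] into {'model': 'mlpmixer'}."""
    overrides = {}
    for arg in argv:
        key, sep, value = arg.partition("=")
        if not sep or not key or not value:
            raise ValueError(f"expected key=value, got {arg!r}")
        overrides[key] = value
    return overrides


def config_files(argv, conf_dir="./conf"):
    """Return the three .gin files for the selected model, in load order."""
    settings = {**DEFAULTS, **parse_overrides(argv)}
    root = Path(conf_dir)
    return [
        root / "models" / f"{settings['model']}.gin",  # model + hyperparameters
        root / "optimizer" / "config.gin",             # optimizer settings
        root / "others" / "config.gin",                # loss_fn, batch_size, ...
    ]
```

Under these assumptions, `config_files(["model=mlpmixer"])` yields the MLP-Mixer model file followed by the optimizer and training configs; such a list could then be passed to `gin.parse_config_files_and_bindings`.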

Config file example:

# ./conf/models/mlpmixer_config.gin
# Trailing parentheses: gin calls MLPMixer and passes the resulting instance.
model_config.model = @MLPMixer()

MLPMixer.config_intro = {
    'n_filters': 128,
    'patch_size': 4
}
MLPMixer.config_feature_extractor = {
    'n_layers': 8,
    'dropout_rate': 0.2,
    # No parentheses: the function itself is passed, not its result.
    'act': @tf.nn.gelu,
    'expansion_rate': 4
}
MLPMixer.config_classifier = {
    'n_filters': (),
    'act': @tf.nn.relu,
    'dropout_rate': 0.5,
    'n_classes': 10
}

Implemented models (model_call_name)

  • Classic MLP: simple_mlp
  • VGGNet: vggnet
  • ResNet: resnet
  • MLP-Mixer: mlpmixer
