-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit bf20f72
Showing
65 changed files
with
7,467 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# Yet another Deep Danbooru project | ||
But based on [RegNetY-8G](https://arxiv.org/abs/2003.13678), relative lightweight, designed to run fast on GPU. \ | ||
Training is done using mixed precision training on a single RTX2080Ti for 3 weeks. \ | ||
Some code are from https://github.com/facebookresearch/pycls | ||
# What do I need? | ||
You need to download [save_4000000.ckpt]() from release and place on the same folder as `test.py`. | ||
# How to use? | ||
`python test.py --model save_4000000.ckpt --image <PATH_TO_IMAGE>` | ||
# What to do in the future? | ||
1. Quantize to 8 bit | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
from RegNetY_8G import build_model | ||
|
||
class RegDanbooru2019(nn.Module) : | ||
def __init__(self) : | ||
super(RegDanbooru2019, self).__init__() | ||
self.backbone = build_model() | ||
num_p = sum(p.numel() for p in self.backbone.parameters() if p.requires_grad) | ||
print( 'Backbone has %d parameters' % num_p ) | ||
self.head_danbooru = nn.Linear(2016, 4096) | ||
|
||
def forward_train_head(self, images) : | ||
""" | ||
images of shape [N, 3, 512, 512] | ||
""" | ||
with torch.no_grad() : | ||
feats = self.backbone(images) | ||
feats = F.adaptive_avg_pool2d(feats, 1).view(-1, 2016) | ||
danbooru_logits = self.head_danbooru(feats) # [N, 4096] | ||
return danbooru_logits | ||
|
||
def forward(self, images) : | ||
""" | ||
images of shape [N, 3, 512, 512] | ||
""" | ||
feats = self.backbone(images) | ||
feats = F.adaptive_avg_pool2d(feats, 1).view(-1, 2016) | ||
danbooru_logits = self.head_danbooru(feats) # [N, 4096] | ||
return danbooru_logits |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
MODEL: | ||
TYPE: regnet | ||
NUM_CLASSES: 1000 | ||
REGNET: | ||
SE_ON: true | ||
DEPTH: 17 | ||
W0: 192 | ||
WA: 76.82 | ||
WM: 2.19 | ||
GROUP_W: 56 | ||
OPTIM: | ||
LR_POLICY: cos | ||
BASE_LR: 0.4 | ||
MAX_EPOCH: 100 | ||
MOMENTUM: 0.9 | ||
WEIGHT_DECAY: 5e-5 | ||
WARMUP_EPOCHS: 5 | ||
TRAIN: | ||
DATASET: imagenet | ||
IM_SIZE: 512 | ||
BATCH_SIZE: 512 | ||
TEST: | ||
DATASET: imagenet | ||
IM_SIZE: 512 | ||
BATCH_SIZE: 400 | ||
NUM_GPUS: 1 | ||
OUT_DIR: . |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
"""Test a trained classification model.""" | ||
|
||
import argparse | ||
import sys | ||
|
||
import numpy as np | ||
import pycls.core.losses as losses | ||
import pycls.core.model_builder as model_builder | ||
import pycls.datasets.loader as loader | ||
import pycls.utils.benchmark as bu | ||
import pycls.utils.checkpoint as cu | ||
import pycls.utils.distributed as du | ||
import pycls.utils.logging as lu | ||
import pycls.utils.metrics as mu | ||
import pycls.utils.multiprocessing as mpu | ||
import pycls.utils.net as nu | ||
import torch | ||
from pycls.core.config import assert_and_infer_cfg, cfg | ||
from pycls.utils.meters import TestMeter | ||
|
||
|
||
logger = lu.get_logger(__name__) | ||
|
||
def log_model_info(model): | ||
"""Logs model info""" | ||
logger.info("Model:\n{}".format(model)) | ||
logger.info("Params: {:,}".format(mu.params_count(model))) | ||
logger.info("Flops: {:,}".format(mu.flops_count(model))) | ||
logger.info("Acts: {:,}".format(mu.acts_count(model))) | ||
|
||
def build_model(): | ||
|
||
# Load config options | ||
cfg.merge_from_file('RegNetY-8.0GF_dds_8gpu.yaml') | ||
cfg.merge_from_list([]) | ||
assert_and_infer_cfg() | ||
cfg.freeze() | ||
# Setup logging | ||
lu.setup_logging() | ||
# Show the config | ||
logger.info("Config:\n{}".format(cfg)) | ||
|
||
# Fix the RNG seeds (see RNG comment in core/config.py for discussion) | ||
np.random.seed(cfg.RNG_SEED) | ||
torch.manual_seed(cfg.RNG_SEED) | ||
# Configure the CUDNN backend | ||
torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK | ||
|
||
# Build the model (before the loaders to speed up debugging) | ||
model = model_builder.build_model() | ||
log_model_info(model) | ||
|
||
# Load model weights | ||
#cu.load_checkpoint('RegNetY-8.0GF_dds_8gpu.pyth', model) | ||
logger.info("Loaded model weights from: {}".format('RegNetY-8.0GF_dds_8gpu.pyth')) | ||
|
||
del model.head | ||
|
||
return model | ||
|
Oops, something went wrong.