
Commit c6718f1
update
1 parent ef958d4

6 files changed, +25 -11 lines

conf/config.yaml (+4)

@@ -5,6 +5,8 @@ defaults:
   - scheduler: step
   - dataset: imagenet
   - pipeline: classifier
+  - hydra/job_logging: colorlog
+  - hydra/hydra_logging: colorlog
 
 pl_trainer:
   accelerator: ddp
@@ -36,6 +38,8 @@ callbacks:
     monitor: val_acc1
     save_top_k: 2
     verbose: true
+  - _target_: pytorch_lightning.callbacks.LearningRateMonitor
+    logging_interval: 'step'
 
 run_test: true
 seed: 2021
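
The two new defaults entries switch Hydra's job and framework logging to the colorlog formatter (backed by the hydra_colorlog package added to docker/requirements.txt below), and the extra callback makes PyTorch Lightning record the learning rate at every optimizer step. As a rough sketch of how a callbacks list like this is usually consumed, assuming the project builds its Trainer from the composed config with hydra.utils.instantiate (the entry point name and exact wiring here are illustrative, not taken from this repository):

# Hypothetical train entry point; only the config values shown above come from this commit.
import hydra
from hydra.utils import instantiate
import pytorch_lightning as pl

@hydra.main(config_path='conf', config_name='config')
def main(cfg):
    # Each callbacks entry carries a _target_, so instantiate() builds the object,
    # e.g. pytorch_lightning.callbacks.LearningRateMonitor(logging_interval='step').
    callbacks = [instantiate(c) for c in cfg.callbacks]
    trainer = pl.Trainer(callbacks=callbacks, accelerator=cfg.pl_trainer.accelerator)
    # ... trainer.fit(model, datamodule) would follow here

if __name__ == '__main__':
    main()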

conf/optim/adam.yaml (+3)

@@ -0,0 +1,3 @@
+# @package _group_
+_target_: torch.optim.Adam
+lr: 0.001

conf/optim/adamw.yaml (+3)

@@ -0,0 +1,3 @@
+# @package _group_
+_target_: torch.optim.AdamW
+lr: 0.001

conf/scheduler/cosine.yaml (+3)

@@ -0,0 +1,3 @@
+# @package _group_
+_target_: torch.optim.lr_scheduler.CosineAnnealingLR
+T_max: 90
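
Each of the three new group configs starts with # @package _group_, which places its contents under the file's config group (optim or scheduler), while _target_ names the class Hydra should instantiate. A minimal sketch of how such groups are commonly wired into a LightningModule, assuming the composed config is stored on the module as self.cfg (the attribute name and this method body are assumptions for illustration, not a copy of pipeline/classifier.py):

# Hypothetical configure_optimizers using the optim and scheduler groups above.
from hydra.utils import instantiate

def configure_optimizers(self):
    # optim=adamw      -> torch.optim.AdamW(self.parameters(), lr=0.001)
    optimizer = instantiate(self.cfg.optim, params=self.parameters())
    # scheduler=cosine -> torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=90)
    scheduler = instantiate(self.cfg.scheduler, optimizer=optimizer)
    return [optimizer], [scheduler]

With these files in place, the optimizer and schedule can be swapped at the command line with Hydra overrides such as optim=adamw scheduler=cosine.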

docker/requirements.txt (+1)

@@ -1,4 +1,5 @@
 ipython
 pytorch-lightning==1.2.5
 hydra-core==1.0.6
+hydra_colorlog
 omegaconf==2.0.6
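
hydra_colorlog is the Hydra plugin that registers the colorlog configurations referenced by the new hydra/job_logging: colorlog and hydra/hydra_logging: colorlog defaults in conf/config.yaml; without it, composing those defaults would fail with a missing-config error at startup.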

pipeline/classifier.py (+11 -11)

@@ -32,13 +32,6 @@ def __init__(
     def forward(self, x):
         return self.model(x)
 
-    def _log_metrics(self):
-        if self.trainer.is_global_zero:
-            str_metrics = ''
-            for key, val in self.trainer.logged_metrics.items():
-                str_metrics += f'\n\t{key}: {val}'
-            logger.info(str_metrics)
-
     def training_step(self, batch, batch_idx):
         images, target = batch
         output = self(images)
@@ -48,7 +41,7 @@ def training_step(self, batch, batch_idx):
         self.log('train_acc1', acc1, on_step=True, prog_bar=True, on_epoch=True)
         self.log('train_acc5', acc5, on_step=True, on_epoch=True)
         return loss_train
-
+
     def validation_step(self, batch, batch_idx):
         images, target = batch
         output = self(images)
@@ -59,9 +52,6 @@ def validation_step(self, batch, batch_idx):
         self.log('val_acc5', acc5, on_epoch=True)
         return loss_val
 
-    def on_validation_end(self):
-        self._log_metrics()
-
     @staticmethod
     def __accuracy(output, target, topk=(1, )):
         """Computes the accuracy over the k top predictions for the specified values of k"""
@@ -89,6 +79,16 @@ def test_step(self, batch, batch_idx):
         self.log('test_acc5', acc5, on_epoch=True)
         return loss_test
 
+    def _log_metrics(self):
+        if self.trainer.is_global_zero:
+            str_metrics = ''
+            for key, val in self.trainer.logged_metrics.items():
+                str_metrics += f'\n\t{key}: {val}'
+            logger.info(str_metrics)
+
+    def on_validation_end(self):
+        self._log_metrics()
+
     def on_test_end(self):
         self._log_metrics()
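
The _log_metrics helper, now grouped with the on_validation_end and on_test_end hooks that call it, prints only from the rank-zero process (self.trainer.is_global_zero), so under the ddp accelerator configured in conf/config.yaml the summary of trainer.logged_metrics is written once at the end of each validation and test run rather than once per GPU.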
