Models can be logged as W&B artifacts. #511

Open · wants to merge 1 commit into master
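This PR adds an opt-in hook that uploads saved checkpoints to Weights & Biases as model artifacts: a `get_save_path` helper is factored out of `BaseModel.save_network`, a `wandb_enabled` predicate replaces the inline guard in `train.py`, and a new `log_artifact` utility runs after each checkpoint save. As a rough sketch, the option keys the new code reads would look like this once the YAML file is parsed into a dict (experiment name, paths, and frequency are illustrative, not from the PR):

opt = {
    'name': 'train_SRResNet_x4',  # a name containing 'debug' disables wandb entirely
    'path': {'models': 'experiments/train_SRResNet_x4/models'},
    'logger': {
        'use_tb_logger': True,  # asserted to be True whenever wandb is enabled
        'save_checkpoint_freq': 5000,  # log_artifact runs at each checkpoint save
        'wandb': {
            'project': 'basicsr',  # required, otherwise wandb_enabled(opt) is False
            'log_model': True,  # new flag read by log_artifact; defaults to False
        },
    },
}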
11 changes: 7 additions & 4 deletions basicsr/models/base_model.py
@@ -192,6 +192,12 @@ def update_learning_rate(self, current_iter, warmup_iter=-1):
     def get_current_learning_rate(self):
         return [param_group['lr'] for param_group in self.optimizers[0].param_groups]
 
+    def get_save_path(self, net_label, current_iter):
+        if current_iter == -1:
+            current_iter = 'latest'
+        save_filename = f'{net_label}_{current_iter}.pth'
+        return os.path.join(self.opt['path']['models'], save_filename)
+
     @master_only
     def save_network(self, net, net_label, current_iter, param_key='params'):
         """Save networks.
@@ -203,10 +209,7 @@ def save_network(self, net, net_label, current_iter, param_key='params'):
         param_key (str | list[str]): The parameter key(s) to save network.
             Default: 'params'.
         """
-        if current_iter == -1:
-            current_iter = 'latest'
-        save_filename = f'{net_label}_{current_iter}.pth'
-        save_path = os.path.join(self.opt['path']['models'], save_filename)
+        save_path = self.get_save_path(net_label, current_iter)
 
         net = net if isinstance(net, list) else [net]
         param_key = param_key if isinstance(param_key, list) else [param_key]
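The refactor above only extracts the path logic so that `log_artifact` can locate the checkpoint `save_network` wrote. A minimal, runnable sketch of the naming convention (the models directory is made up for illustration):

import os

class _DemoModel:
    # Stand-in for BaseModel, assuming only the keys get_save_path reads.
    opt = {'path': {'models': 'experiments/demo/models'}}

    def get_save_path(self, net_label, current_iter):
        if current_iter == -1:
            current_iter = 'latest'
        save_filename = f'{net_label}_{current_iter}.pth'
        return os.path.join(self.opt['path']['models'], save_filename)

print(_DemoModel().get_save_path('net_g', 10000))  # experiments/demo/models/net_g_10000.pth
print(_DemoModel().get_save_path('net_g', -1))     # experiments/demo/models/net_g_latest.pth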
6 changes: 3 additions & 3 deletions basicsr/train.py
@@ -10,14 +10,13 @@
 from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
 from basicsr.models import build_model
 from basicsr.utils import (AvgTimer, MessageLogger, check_resume, get_env_info, get_root_logger, get_time_str,
-                           init_tb_logger, init_wandb_logger, make_exp_dirs, mkdir_and_rename, scandir)
+                           init_tb_logger, wandb_enabled, init_wandb_logger, log_artifact, make_exp_dirs, mkdir_and_rename, scandir)
 from basicsr.utils.options import copy_opt_file, dict2str, parse_options
 
 
 def init_tb_loggers(opt):
     # initialize wandb logger before tensorboard logger to allow proper sync
-    if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project')
-                                                     is not None) and ('debug' not in opt['name']):
+    if wandb_enabled(opt):
         assert opt['logger'].get('use_tb_logger') is True, ('should turn on tensorboard when using wandb')
         init_wandb_logger(opt)
     tb_logger = None
@@ -184,6 +183,7 @@ def train_pipeline(root_path):
             if current_iter % opt['logger']['save_checkpoint_freq'] == 0:
                 logger.info('Saving models and training states.')
                 model.save(epoch, current_iter)
+                log_artifact(opt, model, 'net_g', current_iter)
 
             # validation
             if opt.get('val') is not None and (current_iter % opt['val']['val_freq'] == 0):
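`wandb_enabled` (defined in the logger.py hunk below) folds the three conditions of the old inline guard into one reusable predicate. A quick behavioral check under assumed option dicts:

from basicsr.utils import wandb_enabled  # available after this PR; body shown in logger.py below

wandb_enabled({'name': 'debug_run', 'logger': {'wandb': {'project': 'basicsr'}}})  # False: 'debug' in name
wandb_enabled({'name': 'train_run', 'logger': {'wandb': {'project': 'basicsr'}}})  # True
wandb_enabled({'name': 'train_run', 'logger': {'wandb': {}}})                      # False: no project
wandb_enabled({'name': 'train_run', 'logger': {}})                                 # False: no wandb section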
4 changes: 3 additions & 1 deletion basicsr/utils/__init__.py
@@ -2,7 +2,7 @@
 from .file_client import FileClient
 from .img_process_util import USMSharp, usm_sharp
 from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img
-from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
+from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger, wandb_enabled, log_artifact
 from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt
 
 __all__ = [
@@ -19,6 +19,8 @@
     'AvgTimer',
     'init_tb_logger',
     'init_wandb_logger',
+    'wandb_enabled',
+    'log_artifact',
     'get_root_logger',
     'get_env_info',
     # misc.py
26 changes: 26 additions & 0 deletions basicsr/utils/logger.py
@@ -143,6 +143,32 @@ def init_wandb_logger(opt):
     logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')
 
 
+def wandb_enabled(opt):
+    if 'debug' in opt['name']: return False
+    if opt['logger'].get('wandb') is None: return False
+    if opt['logger']['wandb'].get('project') is None: return False
+    return True
+
+
+@master_only
+def log_artifact(opt, model, net_label, current_step):
+    if not wandb_enabled(opt):
+        return
+
+    if not opt['logger']['wandb'].get('log_model', False):
+        return
+
+    import wandb
+
+    # Prepend run id to artifact name so it is attributed to the current run
+    name = opt['name']
+    name = f'{wandb.run.id}_{name}'
+    artifact = wandb.Artifact(name, type='model')
+    save_path = model.get_save_path(net_label, current_step)
+    artifact.add_file(save_path)
+    wandb.run.log_artifact(artifact)
+
+
 def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
     """Get the root logger.
 
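For readers unfamiliar with the wandb artifact API used by `log_artifact`, the calls reduce to the following standalone sketch (project, run, and file names are assumptions; the checkpoint must already exist on disk):

import wandb

run = wandb.init(project='basicsr', name='artifact-demo')

# Mirror log_artifact: prefix the artifact name with the run id, attach the
# checkpoint file, and log it so wandb versions it under type 'model'.
artifact = wandb.Artifact(f'{run.id}_artifact-demo', type='model')
artifact.add_file('experiments/demo/models/net_g_latest.pth')  # must exist
run.log_artifact(artifact)
run.finish()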