Skip to content

Commit

Permalink
Refactor test_prepare_checkpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Jul 6, 2021
1 parent e669857 commit 26ed43c
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 32 deletions.
3 changes: 3 additions & 0 deletions examples/tensorflow/common/prepare_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def get_config_and_model_type_from_argv(argv, parser):
raise RuntimeError('Wrong model type specified')

predefined_config.update(config_from_json)
if not predefined_config.ckpt_path:
raise RuntimeError('Checkpoint path should be specified')

return predefined_config, args.model_type


Expand Down
48 changes: 16 additions & 32 deletions tests/tensorflow/test_sanity_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import os
import tempfile
from functools import partial
from pathlib import Path
import pytest
import tensorflow as tf

Expand Down Expand Up @@ -366,60 +365,45 @@ def test_export_with_resume(_config, tmp_path, export_format, _case_common_dirs)
assert os.path.exists(model_path)


def get_prepare_checkpoint_configs():
supported_model_types = ['object_detection', 'segmentation']
config_params = []
for sample_type in supported_model_types:
config_paths, batch_sizes = CONFIGS[sample_type], GLOBAL_BATCH_SIZE[sample_type]
dataset_names, dataset_types = zip(*DATASETS[sample_type])

for config_path, dataset_name, dataset_type, batch_size in \
zip(config_paths, dataset_names, dataset_types, batch_sizes):
dataset_path = DATASET_PATHS[sample_type][dataset_name](None)
PREPARE_CHECKPOINTS_SUPPORTED_SAMPLE_TYPES = ['object_detection', 'segmentation']

with config_path.open() as f:
jconfig = json.load(f)

if 'checkpoint_save_dir' in jconfig.keys():
del jconfig['checkpoint_save_dir']

jconfig['dataset'] = dataset_name
jconfig['dataset_type'] = dataset_type
config_params.append((sample_type, config_path, jconfig, dataset_path, batch_size))
return config_params

@pytest.mark.dependency(depends=['tf_test_model_train'])
def test_prepare_checkpoint(_config, tmp_path, _case_common_dirs):
if _config['sample_type'] not in PREPARE_CHECKPOINTS_SUPPORTED_SAMPLE_TYPES:
pytest.skip('Unsupported sample type for test_prepare_checkpoints')

@pytest.mark.parametrize('sample_type,config_path,config_eval,dataset_path,batch_size',
get_prepare_checkpoint_configs(),
ids=[x[0] for x in get_prepare_checkpoint_configs()])
def test_prepare_checkpoint(sample_type, config_path, config_eval, dataset_path, batch_size, tmp_path):
# Keep default soft_device_placement state
default_soft_device_placement = tf.config.get_soft_device_placement()
tf.config.set_soft_device_placement(True)
checkpoint_save_dir = tmp_path
log_dir = tempfile.mkdtemp()
resume_path = os.path.join(_case_common_dirs['checkpoint_save_dir'], _config['tid'])
config_factory = ConfigFactory(_config['nncf_config'], tmp_path / 'config.json')
args = {
'--model-type': sample_type,
'--config': config_path,
'--model-type': _config['sample_type'],
'--config': config_factory.serialize(),
'--checkpoint-save-dir': checkpoint_save_dir,
'--resume': tempfile.mkdtemp(),
'--resume': resume_path,
}

prepare_checkpoint_main(convert_to_argv(args))

assert tf.io.gfile.isdir(checkpoint_save_dir)
assert tf.train.latest_checkpoint(checkpoint_save_dir)
config_factory = ConfigFactory(config_eval, Path(tempfile.gettempdir()) / 'config.json')
args = {
'--mode': 'test',
'--data': dataset_path,
'--data': _config['dataset_path'],
'--config': config_factory.serialize(),
'--log-dir': log_dir,
'--batch-size': batch_size,
'--batch-size': _config['batch_size'],
'--resume': checkpoint_save_dir
}

main = get_sample_fn(sample_type, modes=['test'])
# TODO(nlyalyus): a WA for 58902 issue with matching layer indexes from builder state and from loaded model
tf.keras.backend.clear_session()

main = get_sample_fn(_config['sample_type'], modes=['test'])
main(convert_to_argv(args))
# Restore default soft_device_placement state
tf.config.set_soft_device_placement(default_soft_device_placement)

0 comments on commit 26ed43c

Please sign in to comment.