diff --git a/train_classification.py b/train_classification.py index a41f0a8..b014258 100644 --- a/train_classification.py +++ b/train_classification.py @@ -44,7 +44,7 @@ def main(args): print_info_message('Loading pretrained basenet model weights') model_dict = model.state_dict() - overlap_dict = {k: v for k, v in model_dict.items() if k in pretrained_dict} + overlap_dict = {k: v for k, v in pretrained_dict if k in model_dict.items()} total_size_overlap = 0 for k, v in enumerate(overlap_dict):