From ee99b4b17bd74a3ca392bb63fc16891c69348735 Mon Sep 17 00:00:00 2001 From: Shannon Shen <22512825+lolipopshock@users.noreply.github.com> Date: Mon, 14 Feb 2022 14:38:24 -0800 Subject: [PATCH] Add lp model weights loading in training --- tools/train_net.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tools/train_net.py b/tools/train_net.py index ba6ab23..d2ed086 100644 --- a/tools/train_net.py +++ b/tools/train_net.py @@ -213,9 +213,34 @@ def main(args): "--image_path_val", help="The path to the validation set image folder", ) + parser.add_argument("--lp_model", default="", help="The name of the layoutparser model") + args = parser.parse_args() print("Command Line Args:", args) + if args.lp_model is not None: + + try: + from layoutparser.models.detectron2.catalog import PathManager + except ImportError: + print("Please install the latest version of layoutparser to use LP Model weights for fine-tuning.") + print("\t pip install layoutparser") + exit() + + assert args.lp_model.startswith("lp://"), "Please use Detectron2 models from https://layout-parser.github.io/platform/" + + model_path = PathManager.get_local_path(args.lp_model.rstrip("/weight") + "/weight") + config_path = PathManager.get_local_path(args.lp_model.rstrip("/config") + "/config") + + if "MODEL.WEIGHTS" in args.opts: + idx = args.opts.index("MODEL.WEIGHTS") + args.opts[idx+1] = model_path + else: + args.opts.extend(["MODEL.WEIGHTS", model_path]) + + args.config_file = config_path + + # Dataset Registration is moved to the main function to support multi-gpu training # See ref https://github.com/facebookresearch/detectron2/issues/253#issuecomment-554216517