diff --git a/utils/tool.py b/utils/tool.py index fd4ee4f..77a79e2 100644 --- a/utils/tool.py +++ b/utils/tool.py @@ -8,8 +8,8 @@ def __init__(self, path): with open(path, encoding='utf8') as f: data = yaml.load(f, Loader=yaml.FullLoader) - self.val_txt = data["DATASET"]["VAL"] - self.train_txt = data["DATASET"]["TRAIN"] + self.val_txt = self.getImgPath_toOneFile(data["DATASET"]["VAL"],'val') + self.train_txt = self.getImgPath_toOneFile(data["DATASET"]["TRAIN"],'train') self.names = data["DATASET"]["NAMES"] self.learn_rate = data["TRAIN"]["LR"] @@ -23,6 +23,26 @@ def __init__(self, path): self.category_num = data["MODEL"]["NC"] print("Load yaml sucess...") + + def getImgPath_toOneFile(self,yaml_path:str,name_type:str): + """ + generate txt file : train.txt or val.txt + """ + if yaml_path.endswith('txt'): + return yaml_path + + pre_path = os.path.join(yaml_path,'../') + txt_path = os.path.join(pre_path,name_type+'.txt') + f_txt = open(txt_path,'w') + img_path_list = os.listdir(yaml_path) + imgTypeList = ['jpg','png','bmp'] + for content in img_path_list: + if content.split('.')[-1] not in imgTypeList:continue + img_full_path = os.path.join(yaml_path,content) + f_txt.write(img_full_path + '\n') + pass + f_txt.close() + return txt_path class EMA(): def __init__(self, model, decay):