@@ -80,6 +80,7 @@ def __init__(self, ann_file, data_root, img_prefix, pipeline, classes):
80
80
self .CLASSES = classes # classes를 CLASSES 변수에 할당
81
81
82
82
super (AihubDataset , self ).__init__ (ann_file , data_root , img_prefix , pipeline )
83
+
83
84
def load_annotations (self , ann_file ):
84
85
print ('##### self.data_root:' , self .data_root , 'self.ann_file:' , self .ann_file , 'self.img_prefix:' , self .img_prefix )
85
86
print ('#### ann_file:' , ann_file )
@@ -124,7 +125,22 @@ def load_annotations(self, ann_file):
124
125
return data_infos
125
126
126
127
def train ():
127
- datasets = [build_dataset (cfg .data .train )]
128
+ # config에서 train 데이터셋 정보 가져오기
129
+ train_dataset = copy .deepcopy (cfg .data .train )
130
+ train_dataset .pipeline = cfg .train_pipeline
131
+
132
+ # 데이터셋 빌드를 위한 설정 추가
133
+ dataset_info = dict (
134
+ type = cfg .dataset_type ,
135
+ data_root = cfg .data_root ,
136
+ ann_file = cfg .data .train .ann_file , # ann_file 추가
137
+ img_prefix = train_dataset .img_prefix ,
138
+ classes = cfg .model .roi_head .bbox_head .num_classes ,
139
+ pipeline = train_dataset .pipeline
140
+ )
141
+
142
+ datasets = [build_dataset (dataset_info )]
143
+
128
144
model = build_detector (cfg .model , train_cfg = cfg .get ('train_cfg' ), test_cfg = cfg .get ('test_cfg' ))
129
145
model .CLASSES = datasets [0 ].CLASSES
130
146
0 commit comments