@@ -31,6 +31,12 @@ def __init__(self, cfg: Config, model):
3131 self .contrastive_loss = NT_Xent
3232 self .cfg = cfg
3333 self .metric = meanBoxCoverScore ()
34+ self .target_source_train = torch .load (
35+ f"{ cfg .dataset .path } /target/from_{ cfg .task .data .target_source } _train.pt" , weights_only = False
36+ )
37+ self .target_source_val = torch .load (
38+ f"{ cfg .dataset .path } /target/from_{ cfg .task .data .target_source } _val.pt" , weights_only = False
39+ )
3440
3541 def set_task (self , task ):
3642 self .task = task
@@ -42,12 +48,14 @@ def setup(self, stage):
4248 )
4349 self .loss_fn = AAAILoss (self .cfg .task .loss , self .vec2box )
4450 self .post_process = PostProcess (self .vec2box , self .cfg .task .validation .nms , aaai = True )
51+ self .target_source_train = self .target_source_train .to (self .device )
52+ self .target_source_val = self .target_source_val .to (self .device )
4553
4654 def forward (self , x , external = None , shortcut = None ):
4755 return self .model (x , external , shortcut )
4856
4957 def training_step (self , batch , batch_idx ):
50- images_batch , masks_batch , puzzles_batch , idx_batch = batch
58+ images_batch , masks_batch , _ , idx_batch = batch
5159 R_loss , C_loss , D_loss = 0 , 0 , 0
5260 images , bbox = images_batch
5361 batch_size = images .size (0 )
@@ -61,18 +69,9 @@ def training_step(self, batch, batch_idx):
6169 R_loss_msk = self .construct_loss (masked_outputs ["RMAP" ] * masks , images * masks )
6270 R_loss = R_loss_all + R_loss_msk * 10
6371
64- if self .task == "puzzle" or self .task == "detect" :
65- puzzle_images , origin_idx , puzzle_idx = puzzles_batch
66- puzzle_outputs = self (puzzle_images , shortcut = "FMAP" )
67- picked_puzzle = puzzle_outputs ["FMAP" ].detach ()[batch_step [:, None ], :, puzzle_idx [:, 0 ], puzzle_idx [:, 1 ]]
68-
69- if self .task == "puzzle" :
70- picked_origin = origin_outputs ["FMAP" ][:, :, origin_idx [0 , 0 ], origin_idx [0 , 1 ]]
71- C_loss = self .contrastive_loss (picked_puzzle , picked_origin )
72-
7372 if self .task == "detect" :
74- _ , pick_idx = idx_batch
75- picked_vector = picked_puzzle . view ( batch_size , - 1 , 512 )[ batch_step , pick_idx ]
73+ image_idx , pick_idx = idx_batch
74+ picked_vector = self . target_source_train [ image_idx [:, None ] , pick_idx ]
7675
7776 origin_outputs = self (images , dict (target = picked_vector .permute (0 , 2 , 1 )))
7877 detections = self .vec2box (origin_outputs ["Main" ])
@@ -132,12 +131,9 @@ def validation_step(self, batch, batch_idx):
132131
133132 if self .task == "detect" :
134133
135- puzzle_images , origin_idx , puzzle_idx = puzzles_batch
136- puzzle_outputs = self (puzzle_images , shortcut = "FMAP" )
137- picked_puzzle = puzzle_outputs ["FMAP" ].detach ()[batch_step [:, None ], :, puzzle_idx [:, 0 ], puzzle_idx [:, 1 ]]
138-
139- _ , pick_idx = idx_batch
140- picked_vector = picked_puzzle .view (batch_size , - 1 , 512 )[batch_step , pick_idx ]
134+ puzzle_images , origin_idx , puzzle_idx , puzzles = puzzles_batch
135+ image_idx , pick_idx = idx_batch
136+ picked_vector = self .target_source_val [image_idx [:, None ], pick_idx ]
141137
142138 origin_outputs = self (images , dict (target = picked_vector .permute (0 , 2 , 1 )))
143139 H , W = images .shape [2 :]
@@ -162,16 +158,13 @@ def validation_step(self, batch, batch_idx):
162158 def on_validation_batch_end (self , outputs , batch , batch_idx ):
163159 if batch_idx != 0 :
164160 return
165- images_batch , _ , puzzles_batch , _ = batch
161+ images_batch , _ , _ , _ = batch
166162 images , bbox = images_batch
167- puzzle_images , _ , _ = puzzles_batch
168163 origin_image = draw_bboxes (images [0 ], bbox [0 ])
169164 predict_image = draw_bboxes (images [0 ], outputs [0 ])
170- puzzle_image = to_pil_image (puzzle_images [0 ])
171165 for logger in self .loggers :
172166 if isinstance (logger , WandbLogger ):
173167 logger .log_image (f"Origin Image" , [origin_image ], self .current_epoch )
174- logger .log_image (f"Puzzle Visualize" , [puzzle_image ], self .current_epoch )
175168 logger .log_image (f"Predict Visualize" , [predict_image ], self .current_epoch )
176169
177170 def on_validation_start (self ):
@@ -236,7 +229,7 @@ def main(cfg: Config):
236229 deterministic = True ,
237230 logger = loggers ,
238231 devices = [0 ],
239- gradient_clip_algorithm = ' norm' ,
232+ gradient_clip_algorithm = " norm" ,
240233 gradient_clip_val = 10 ,
241234 callbacks = [checkpoint_callback , RichProgressBar (), YOLORichModelSummary ()],
242235 accelerator = "auto" ,
0 commit comments