Skip to content

Commit 0465f4b

Browse files
committed
🔀 [AAAI|Merge] branch 'feature/use-preprocess-target'
2 parents cf23952 + 1f5ea93 commit 0465f4b

File tree

4 files changed

+39
-31
lines changed

4 files changed

+39
-31
lines changed

‎yolo/aaai.py‎

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@ def __init__(self, cfg: Config, model):
3131
self.contrastive_loss = NT_Xent
3232
self.cfg = cfg
3333
self.metric = meanBoxCoverScore()
34+
self.target_source_train = torch.load(
35+
f"{cfg.dataset.path}/target/from_{cfg.task.data.target_source}_train.pt", weights_only=False
36+
)
37+
self.target_source_val = torch.load(
38+
f"{cfg.dataset.path}/target/from_{cfg.task.data.target_source}_val.pt", weights_only=False
39+
)
3440

3541
def set_task(self, task):
3642
self.task = task
@@ -42,12 +48,14 @@ def setup(self, stage):
4248
)
4349
self.loss_fn = AAAILoss(self.cfg.task.loss, self.vec2box)
4450
self.post_process = PostProcess(self.vec2box, self.cfg.task.validation.nms, aaai=True)
51+
self.target_source_train = self.target_source_train.to(self.device)
52+
self.target_source_val = self.target_source_val.to(self.device)
4553

4654
def forward(self, x, external=None, shortcut=None):
4755
return self.model(x, external, shortcut)
4856

4957
def training_step(self, batch, batch_idx):
50-
images_batch, masks_batch, puzzles_batch, idx_batch = batch
58+
images_batch, masks_batch, _, idx_batch = batch
5159
R_loss, C_loss, D_loss = 0, 0, 0
5260
images, bbox = images_batch
5361
batch_size = images.size(0)
@@ -61,18 +69,9 @@ def training_step(self, batch, batch_idx):
6169
R_loss_msk = self.construct_loss(masked_outputs["RMAP"] * masks, images * masks)
6270
R_loss = R_loss_all + R_loss_msk * 10
6371

64-
if self.task == "puzzle" or self.task == "detect":
65-
puzzle_images, origin_idx, puzzle_idx = puzzles_batch
66-
puzzle_outputs = self(puzzle_images, shortcut="FMAP")
67-
picked_puzzle = puzzle_outputs["FMAP"].detach()[batch_step[:, None], :, puzzle_idx[:, 0], puzzle_idx[:, 1]]
68-
69-
if self.task == "puzzle":
70-
picked_origin = origin_outputs["FMAP"][:, :, origin_idx[0, 0], origin_idx[0, 1]]
71-
C_loss = self.contrastive_loss(picked_puzzle, picked_origin)
72-
7372
if self.task == "detect":
74-
_, pick_idx = idx_batch
75-
picked_vector = picked_puzzle.view(batch_size, -1, 512)[batch_step, pick_idx]
73+
image_idx, pick_idx = idx_batch
74+
picked_vector = self.target_source_train[image_idx[:, None], pick_idx]
7675

7776
origin_outputs = self(images, dict(target=picked_vector.permute(0, 2, 1)))
7877
detections = self.vec2box(origin_outputs["Main"])
@@ -132,12 +131,9 @@ def validation_step(self, batch, batch_idx):
132131

133132
if self.task == "detect":
134133

135-
puzzle_images, origin_idx, puzzle_idx = puzzles_batch
136-
puzzle_outputs = self(puzzle_images, shortcut="FMAP")
137-
picked_puzzle = puzzle_outputs["FMAP"].detach()[batch_step[:, None], :, puzzle_idx[:, 0], puzzle_idx[:, 1]]
138-
139-
_, pick_idx = idx_batch
140-
picked_vector = picked_puzzle.view(batch_size, -1, 512)[batch_step, pick_idx]
134+
puzzle_images, origin_idx, puzzle_idx, puzzles = puzzles_batch
135+
image_idx, pick_idx = idx_batch
136+
picked_vector = self.target_source_val[image_idx[:, None], pick_idx]
141137

142138
origin_outputs = self(images, dict(target=picked_vector.permute(0, 2, 1)))
143139
H, W = images.shape[2:]
@@ -162,16 +158,13 @@ def validation_step(self, batch, batch_idx):
162158
def on_validation_batch_end(self, outputs, batch, batch_idx):
163159
if batch_idx != 0:
164160
return
165-
images_batch, _, puzzles_batch, _ = batch
161+
images_batch, _, _, _ = batch
166162
images, bbox = images_batch
167-
puzzle_images, _, _ = puzzles_batch
168163
origin_image = draw_bboxes(images[0], bbox[0])
169164
predict_image = draw_bboxes(images[0], outputs[0])
170-
puzzle_image = to_pil_image(puzzle_images[0])
171165
for logger in self.loggers:
172166
if isinstance(logger, WandbLogger):
173167
logger.log_image(f"Origin Image", [origin_image], self.current_epoch)
174-
logger.log_image(f"Puzzle Visualize", [puzzle_image], self.current_epoch)
175168
logger.log_image(f"Predict Visualize", [predict_image], self.current_epoch)
176169

177170
def on_validation_start(self):
@@ -236,7 +229,7 @@ def main(cfg: Config):
236229
deterministic=True,
237230
logger=loggers,
238231
devices=[0],
239-
gradient_clip_algorithm='norm',
232+
gradient_clip_algorithm="norm",
240233
gradient_clip_val=10,
241234
callbacks=[checkpoint_callback, RichProgressBar(), YOLORichModelSummary()],
242235
accelerator="auto",

‎yolo/config/config.py‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class DataConfig:
6666
class OriginConfig:
6767
scale_origin: float
6868
background_color: int
69-
target_file: str
69+
target_postfix: str
7070

7171

7272
@dataclass
@@ -87,6 +87,7 @@ class AAAIDataConfig(DataConfig):
8787
mask: MaskConfig
8888
puzzle: PuzzleConfig
8989
num_target: int
90+
target_source: str
9091

9192

9293
@dataclass

‎yolo/config/task/aaai.yaml‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ target: detect
44
defaults:
55
- validation: ../validation
66

7-
epoch: 500
7+
epoch: 25
88

99
validation:
1010
nms:
@@ -22,6 +22,7 @@ data:
2222
resolution: 32
2323
size: 160
2424
num_target: 3
25+
target_source: clip
2526
batch_size: 16
2627
image_size: ${image_size}
2728
cpu_num: ${cpu_num}

‎yolo/tools/data_loader.py‎

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,14 @@ def __len__(self):
385385
return len(self.image_paths)
386386

387387
def __getitem__(self, idx: int):
388+
"""
389+
Returns:
390+
- origin_image [3, H, W], bboxes [Box, num_target]
391+
- masked_image [3, H, W], mask [3, H, W]
392+
- puzzle_image [3, H, W], origin_grid [2, p_num, p_num], \
393+
puzzle_grid [2, p_num, p_num], puzzles [3, p_num, p_num, p_size, p_size]
394+
- idx [int], pick_idx [p_num]
395+
"""
388396
image_path = self.dataset_path / self.image_paths[idx]
389397
image = Image.open(image_path).convert("RGB")
390398
if self.transform:
@@ -398,10 +406,15 @@ def __getitem__(self, idx: int):
398406
if hasattr(self, "filter_box"):
399407
bboxes = bboxes[pick_idx]
400408
image = to_tensor(image)
401-
masked_image, mask = self.mask(image)
402-
puzzle_image, origin_grid, puzzle_grid = self.puzzle(image, shift_hw)
403-
404-
return (origin_image, bboxes), (masked_image, mask), (puzzle_image, origin_grid, puzzle_grid), (idx, pick_idx)
409+
# masked_image, mask = self.mask(image)
410+
# puzzle_image, origin_grid, puzzle_grid, puzzles = self.puzzle(image, shift_hw)
411+
412+
return (
413+
(origin_image, bboxes),
414+
([], []),
415+
([], [], [], []),
416+
(idx, pick_idx),
417+
)
405418

406419
def augment_origin(self, image: Image, shift_hw) -> tuple[Tensor, Tensor]:
407420
w, h = self.image_size[0] // self.puzzle_size, self.image_size[1] // self.puzzle_size
@@ -416,8 +429,8 @@ def augment_origin(self, image: Image, shift_hw) -> tuple[Tensor, Tensor]:
416429
.view(-1, 4)
417430
.float()
418431
)
419-
if self.task == "detect":
420-
image, bboxes = random_resize_crop(image, bboxes, self.data_cfg.main)
432+
# if self.task == "detect":
433+
# image, bboxes = random_resize_crop(image, bboxes, self.data_cfg.main)
421434
return to_tensor(image), bboxes
422435

423436
def puzzle(self, image: Tensor, hw: Optional[Tensor] = None) -> tuple[Tensor]:

0 commit comments

Comments (0)