Skip to content

Commit 8ce9eff

Browse files
committed
🔀 [Merge] branch 'Lightning'
2 parents e53ff09 + 7f8235a commit 8ce9eff

20 files changed

+595
-608
lines changed

tests/test_utils/test_bounding_box_utils.py

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -146,23 +146,64 @@ def test_anc2box_autoanchor(inference_v7_cfg: Config):
146146

147147

148148
def test_bbox_nms():
149-
cls_dist = tensor(
150-
[[[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]], [[0.4, 0.4, 0.2], [0.5, 0.4, 0.1]]] # Example class distribution
149+
cls_dist = torch.tensor(
150+
[
151+
[
152+
[0.7, 0.1, 0.2], # High confidence, class 0
153+
[0.3, 0.6, 0.1], # High confidence, class 1
154+
[-3.0, -2.0, -1.0], # low confidence, class 2
155+
[0.6, 0.2, 0.2], # Medium confidence, class 0
156+
],
157+
[
158+
[0.55, 0.25, 0.2], # Medium confidence, class 0
159+
[-4.0, -0.5, -2.0], # low confidence, class 1
160+
[0.15, 0.2, 0.65], # Medium confidence, class 2
161+
[0.8, 0.1, 0.1], # High confidence, class 0
162+
],
163+
],
164+
dtype=float32,
151165
)
152-
bbox = tensor(
153-
[[[50, 50, 100, 100], [60, 60, 110, 110]], [[40, 40, 90, 90], [70, 70, 120, 120]]], # Example bounding boxes
166+
167+
bbox = torch.tensor(
168+
[
169+
[
170+
[0, 0, 160, 120], # Overlaps with box 4
171+
[160, 120, 320, 240],
172+
[0, 120, 160, 240],
173+
[16, 12, 176, 132],
174+
],
175+
[
176+
[0, 0, 160, 120], # Overlaps with box 4
177+
[160, 120, 320, 240],
178+
[0, 120, 160, 240],
179+
[16, 12, 176, 132],
180+
],
181+
],
154182
dtype=float32,
155183
)
184+
156185
nms_cfg = NMSConfig(min_confidence=0.5, min_iou=0.5)
157186

158-
expected_output = [
159-
tensor(
187+
# Batch 1:
188+
# - box 1 is kept with class 0 as it has a higher confidence than box 4 i.e. box 4 is filtered out
189+
# - box 2 is kept with class 1
190+
# - box 3 is rejected by the confidence filter
191+
# Batch 2:
192+
# - box 4 is kept with class 0 as it has a higher confidence than box 1 i.e. box 1 is filtered out
193+
# - box 2 is rejected by the confidence filter
194+
# - box 3 is kept with class 2
195+
expected_output = torch.tensor(
196+
[
160197
[
161-
[1.0000, 50.0000, 50.0000, 100.0000, 100.0000, 0.6682],
162-
[0.0000, 60.0000, 60.0000, 110.0000, 110.0000, 0.6457],
163-
]
164-
)
165-
]
198+
[0.0, 0.0, 0.0, 160.0, 120.0, 0.6682],
199+
[1.0, 160.0, 120.0, 320.0, 240.0, 0.6457],
200+
],
201+
[
202+
[0.0, 16.0, 12.0, 176.0, 132.0, 0.6900],
203+
[2.0, 0.0, 120.0, 160.0, 240.0, 0.6570],
204+
],
205+
]
206+
)
166207

167208
output = bbox_nms(cls_dist, bbox, nms_cfg)
168209

yolo/__init__.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,32 @@
22
from yolo.model.yolo import create_model
33
from yolo.tools.data_loader import AugmentationComposer, create_dataloader
44
from yolo.tools.drawer import draw_bboxes
5-
from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
5+
from yolo.tools.solver import TrainModel
66
from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box, bbox_nms, create_converter
77
from yolo.utils.deploy_utils import FastModelLoader
8-
from yolo.utils.logging_utils import ProgressLogger, custom_logger
8+
from yolo.utils.logging_utils import (
9+
ImageLogger,
10+
YOLORichModelSummary,
11+
YOLORichProgressBar,
12+
)
913
from yolo.utils.model_utils import PostProccess
1014

1115
all = [
1216
"create_model",
1317
"Config",
14-
"ProgressLogger",
18+
"YOLORichProgressBar",
1519
"NMSConfig",
16-
"custom_logger",
20+
"YOLORichModelSummary",
1721
"validate_log_directory",
1822
"draw_bboxes",
1923
"Vec2Box",
2024
"Anc2Box",
2125
"bbox_nms",
2226
"create_converter",
2327
"AugmentationComposer",
28+
"ImageLogger",
2429
"create_dataloader",
2530
"FastModelLoader",
26-
"ModelTester",
27-
"ModelTrainer",
28-
"ModelValidator",
31+
"TrainModel",
2932
"PostProccess",
3033
]

yolo/config/general.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ out_path: runs
77
exist_ok: True
88

99
lucky_number: 10
10-
use_wandb: False
10+
use_wandb: True
1111
use_tensorboard: False
1212

1313
weight: True # Path to weight or True for auto, False for no pretrained weight

yolo/config/task/inference.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@ data:
88
nms:
99
min_confidence: 0.5
1010
min_iou: 0.5
11-
# save_predict: True
11+
save_predict: True

yolo/config/task/validation.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,5 @@ data:
88
pin_memory: True
99
data_augment: {}
1010
nms:
11-
min_confidence: 0.05
12-
min_iou: 0.9
11+
min_confidence: 0.0001
12+
min_iou: 0.7

yolo/lazy.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,41 +2,41 @@
22
from pathlib import Path
33

44
import hydra
5+
from lightning import Trainer
56

67
project_root = Path(__file__).resolve().parent.parent
78
sys.path.append(str(project_root))
89

910
from yolo.config.config import Config
10-
from yolo.model.yolo import create_model
11-
from yolo.tools.data_loader import create_dataloader
12-
from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
13-
from yolo.utils.bounding_box_utils import create_converter
14-
from yolo.utils.deploy_utils import FastModelLoader
15-
from yolo.utils.logging_utils import ProgressLogger
16-
from yolo.utils.model_utils import get_device
11+
from yolo.tools.solver import InferenceModel, TrainModel, ValidateModel
12+
from yolo.utils.logging_utils import setup
1713

1814

1915
@hydra.main(config_path="config", config_name="config", version_base=None)
2016
def main(cfg: Config):
21-
progress = ProgressLogger(cfg, exp_name=cfg.name)
22-
device, use_ddp = get_device(cfg.device)
23-
dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task, use_ddp)
24-
if getattr(cfg.task, "fast_inference", False):
25-
model = FastModelLoader(cfg).load_model(device)
26-
else:
27-
model = create_model(cfg.model, class_num=cfg.dataset.class_num, weight_path=cfg.weight)
28-
model = model.to(device)
29-
30-
converter = create_converter(cfg.model.name, model, cfg.model.anchor, cfg.image_size, device)
31-
32-
if cfg.task.task == "train":
33-
solver = ModelTrainer(cfg, model, converter, progress, device, use_ddp)
34-
if cfg.task.task == "validation":
35-
solver = ModelValidator(cfg.task, cfg.dataset, model, converter, progress, device)
36-
if cfg.task.task == "inference":
37-
solver = ModelTester(cfg, model, converter, progress, device)
38-
progress.start()
39-
solver.solve(dataloader)
17+
callbacks, loggers = setup(cfg)
18+
19+
trainer = Trainer(
20+
accelerator="cuda",
21+
max_epochs=getattr(cfg.task, "epoch", None),
22+
precision="16-mixed",
23+
callbacks=callbacks,
24+
logger=loggers,
25+
log_every_n_steps=1,
26+
gradient_clip_val=10,
27+
deterministic=True,
28+
)
29+
30+
match cfg.task.task:
31+
case "train":
32+
model = TrainModel(cfg)
33+
trainer.fit(model)
34+
case "validation":
35+
model = ValidateModel(cfg)
36+
trainer.validate(model)
37+
case "inference":
38+
model = InferenceModel(cfg)
39+
trainer.predict(model)
4040

4141

4242
if __name__ == "__main__":

yolo/model/module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
import torch
44
import torch.nn.functional as F
55
from einops import rearrange
6-
from loguru import logger
76
from torch import Tensor, nn
87
from torch.nn.common_types import _size_2_t
98

9+
from yolo.utils.logger import logger
1010
from yolo.utils.module_utils import auto_pad, create_activation_function, round_up
1111

1212

yolo/model/yolo.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
from typing import Dict, List, Union
44

55
import torch
6-
from loguru import logger
76
from omegaconf import ListConfig, OmegaConf
87
from torch import nn
98

109
from yolo.config.config import ModelConfig, YOLOLayer
1110
from yolo.tools.dataset_preparation import prepare_weight
11+
from yolo.utils.logger import logger
1212
from yolo.utils.module_utils import get_layer_map
1313

1414

@@ -32,10 +32,10 @@ def __init__(self, model_cfg: ModelConfig, class_num: int = 80):
3232
def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
3333
self.layer_index = {}
3434
output_dim, layer_idx = [3], 1
35-
logger.info(f"🚜 Building YOLO")
35+
logger.info(f":tractor: Building YOLO")
3636
for arch_name in model_arch:
3737
if model_arch[arch_name]:
38-
logger.info(f" 🏗️ Building {arch_name}")
38+
logger.info(f" :building_construction: Building {arch_name}")
3939
for layer_idx, layer_spec in enumerate(model_arch[arch_name], start=layer_idx):
4040
layer_type, layer_info = next(iter(layer_spec.items()))
4141
layer_args = layer_info.get("args", {})
@@ -123,7 +123,7 @@ def save_load_weights(self, weights: Union[Path, OrderedDict]):
123123
weights: A OrderedDict containing the new weights.
124124
"""
125125
if isinstance(weights, Path):
126-
weights = torch.load(weights, map_location=torch.device("cpu"))
126+
weights = torch.load(weights, map_location=torch.device("cpu"), weights_only=False)
127127
if "model_state_dict" in weights:
128128
weights = weights["model_state_dict"]
129129

@@ -144,7 +144,7 @@ def save_load_weights(self, weights: Union[Path, OrderedDict]):
144144

145145
for error_name, error_set in error_dict.items():
146146
for weight_name in error_set:
147-
logger.warning(f"⚠️ Weight {error_name} for key: {'.'.join(weight_name)}")
147+
logger.warning(f":warning: Weight {error_name} for key: {'.'.join(weight_name)}")
148148

149149
self.model.load_state_dict(model_state_dict)
150150

@@ -171,7 +171,7 @@ def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True,
171171
prepare_weight(weight_path=weight_path)
172172
if weight_path.exists():
173173
model.save_load_weights(weight_path)
174-
logger.info(" Success load model & weight")
174+
logger.info(":white_check_mark: Success load model & weight")
175175
else:
176-
logger.info(" Success load model")
176+
logger.info(":white_check_mark: Success load model")
177177
return model

0 commit comments

Comments
 (0)