diff --git a/Segmentation/train.py b/Segmentation/train.py
index 3039b7b..0c8c8bc 100644
--- a/Segmentation/train.py
+++ b/Segmentation/train.py
@@ -23,6 +23,8 @@ def main():
     parser.add_argument('--logging_directory', type=str, default='logs', help='Directory for logging')
     parser.add_argument('--checkpoint_directory', type=str, default='checkpoints', help='Directory for saving checkpoints')
     parser.add_argument('--classes', type=int, default='2', help='No. of classes you want to segment your model into.')
+    parser.add_argument('--iou', action='store_true', help='Compute and log IoU during training')
+    parser.add_argument('--device', type=str, default='cpu', help="Device to train on ('cpu' or 'cuda')")
     args = parser.parse_args()

     # Create the logging directory
@@ -94,7 +96,9 @@ def main():
         num_epochs=args.epochs,
         learning_rate=args.learning_rate,
         checkpoint_dir=args.checkpoint_directory,
-        logger=logging
+        logger=logging,
+        iou=args.iou,
+        device=args.device
     )

diff --git a/Segmentation/utils.py b/Segmentation/utils.py
index b9d547b..ac685fd 100644
--- a/Segmentation/utils.py
+++ b/Segmentation/utils.py
@@ -75,7 +75,16 @@ def parse_folder(dataset_path):
         return False


-def train_unet(model, train_data_loader, test_data_loader, num_epochs, learning_rate, checkpoint_dir, logger=None):
+def train_unet(model, train_data_loader, test_data_loader, num_epochs, learning_rate, checkpoint_dir, logger=None, iou=False, device='cpu'):
+    if device == 'cpu':
+        device = torch.device('cpu')
+    elif device == 'cuda':
+        device = torch.device('cuda')
+    else:
+        print(f"{device} is not a valid device.")
+        return None
+
+    model = model.to(device)
     criterion = nn.CrossEntropyLoss()
     optimizer = optim.Adam(model.parameters(), lr=learning_rate)

@@ -84,36 +93,49 @@ def train_unet(model, train_data_loader, test_data_loader, num_epochs, learning_
     for epoch in range(num_epochs):
         model.train()
         train_loss = 0.0
+        iou_score_mean = 0.0

         for inputs, targets in tqdm(train_data_loader, desc=f'Epoch {epoch + 1}/{num_epochs}', leave=False):
             optimizer.zero_grad()
+            inputs = inputs.to(device)
+            targets = targets.to(device)
             outputs = model(inputs)
             targets = targets.squeeze(1)
             loss = criterion(outputs, targets)
             loss.backward()
             optimizer.step()
             train_loss += loss.item()
+            iou_score_mean += iou_score(outputs, targets).item()

+        iou_score_mean = iou_score_mean / len(train_data_loader)
         average_train_loss = train_loss / len(train_data_loader)
-        if logger:
-            logger.info(f'Epoch {epoch + 1}/{num_epochs}, Train Loss: {average_train_loss:.4f}')
+        if logger and iou:
+            logger.info(f'Epoch {epoch + 1}/{num_epochs}, Train Loss: {average_train_loss:.4f}, IoU Score: {iou_score_mean:.4f}')
+        elif logger:
+            logger.info(f'Epoch {epoch + 1}/{num_epochs}, Train Loss: {average_train_loss:.4f}')

         # Validation
         model.eval()
         val_loss = 0.0
-
+        iou_score_mean = 0.0
         with torch.no_grad():
             for inputs, targets in tqdm(test_data_loader, desc=f'Validation', leave=False):
+                inputs = inputs.to(device)
+                targets = targets.to(device)
                 outputs = model(inputs)
                 targets = targets.squeeze(1)
                 loss = criterion(outputs, targets)
                 val_loss += loss.item()
+                iou_score_mean += iou_score(outputs, targets).item()

+        iou_score_mean = iou_score_mean / len(test_data_loader)
         average_val_loss = val_loss / len(test_data_loader)
-        if logger:
-            logger.info(f'Epoch {epoch + 1}/{num_epochs}, Validation Loss: {average_val_loss:.4f}')
+        if logger and iou:
+            logger.info(f'Epoch {epoch + 1}/{num_epochs}, Validation Loss: {average_val_loss:.4f}, IoU Score: {iou_score_mean:.4f}')
+        elif logger:
+            logger.info(f'Epoch {epoch + 1}/{num_epochs}, Validation Loss: {average_val_loss:.4f}')

         # Save model checkpoint if validation loss improves
         if average_val_loss < best_loss:
@@ -126,4 +148,13 @@ def train_unet(model, train_data_loader, test_data_loader, num_epochs, learning_
     print('Finished Training')

 def generate_model_summary(model, input_size):
-    torchsummary.summary(model, input_size=input_size)
\ No newline at end of file
+    torchsummary.summary(model, input_size=input_size)
+
+def iou_score(output, target):
+    # Assumes binary segmentation: predicted and target masks contain only 0/1 labels.
+    smooth = 1e-6
+    output = output.argmax(1)
+    intersection = (output & target).float().sum((1, 2))
+    union = (output | target).float().sum((1, 2))
+    iou = (intersection + smooth) / (union + smooth)
+    return iou.mean()
\ No newline at end of file
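As a quick sanity check of the new iou_score helper, the snippet below is a minimal sketch run on a hand-built 2x2 binary example; the Segmentation.utils import path and the toy tensors are assumptions for illustration, not part of this change.

    import torch

    from Segmentation.utils import iou_score  # hypothetical import path; adjust to how the repo is packaged

    # Toy batch: 1 sample, 2 classes, 2x2 image.
    # Logits favour class 1 everywhere except the top-left pixel.
    logits = torch.tensor([[[[2.0, -1.0], [-1.0, -1.0]],   # class-0 scores
                            [[0.0,  1.0], [ 1.0,  1.0]]]]) # class-1 scores
    target = torch.tensor([[[0, 1], [1, 0]]])              # ground-truth mask, shape (B, H, W)

    # Prediction after argmax: [[0, 1], [1, 1]] -> intersection = 2, union = 3
    print(iou_score(logits, target))  # ~0.6667

Note that iou_score measures overlap of the foreground (class 1) pixels only, which matches the default two-class setup. With the flags added here, GPU training with IoU logging would be invoked roughly as python train.py --device cuda --iou, alongside the script's existing dataset and epoch arguments.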