Added 'cuda' functionality as well as streamlined the dataloading process for segmentation tasks.
ishan121028 committed Oct 8, 2023
1 parent f4a63d2 commit f036fc3
Showing 8 changed files with 26 additions and 56 deletions.
23 changes: 0 additions & 23 deletions UniTrain.egg-info/PKG-INFO

This file was deleted.

10 changes: 0 additions & 10 deletions UniTrain.egg-info/SOURCES.txt

This file was deleted.

1 change: 0 additions & 1 deletion UniTrain.egg-info/dependency_links.txt

This file was deleted.

5 changes: 0 additions & 5 deletions UniTrain.egg-info/requires.txt

This file was deleted.

1 change: 0 additions & 1 deletion UniTrain.egg-info/top_level.txt

This file was deleted.

14 changes: 10 additions & 4 deletions UniTrain/utils/classification.py
@@ -41,10 +41,8 @@ def get_data_loader(data_dir, batch_size, shuffle=True, transform = None, split=

# Create a custom dataset
dataset = ClassificationDataset(data_dir, transform=transform)
-    print(dataset.__len__())

# Create a data loader
-    print(batch_size)
data_loader = DataLoader(
dataset,
batch_size=batch_size,
@@ -86,7 +84,7 @@ def parse_folder(dataset_path):
print("An error occurred:", str(e))
return None

-def train_model(model, train_data_loader, test_data_loader, num_epochs, learning_rate=0.001, checkpoint_dir='checkpoints', logger=None, device='cpu'):
+def train_model(model, train_data_loader, test_data_loader, num_epochs, learning_rate=0.001, checkpoint_dir='checkpoints', logger=None, device=torch.device('cpu')):
'''Train a PyTorch model for a classification task.
Args:
model (nn.Module): Torch model to train.
@@ -96,11 +94,17 @@ def train_model(model, train_data_loader, test_data_loader, num_epochs, learning
learning_rate (float): Learning rate for the optimizer.
checkpoint_dir (str): Directory to save model checkpoints.
logger (Logger): Logger to log training details.
+        device (torch.device): Device to run training on (GPU or CPU).
Returns:
None
'''

+    if logger:
+        logging.basicConfig(level=logging.INFO, format='%(asctime)s - Epoch %(epoch)d - Train Acc: %(train_acc).4f - Val Acc: %(val_acc).4f - %(message)s',
+                            datefmt='%Y-%m-%d %H:%M:%S', filename=logger, filemode='w')
+        logger = logging.getLogger(__name__)
+
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
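
One caveat on the logging setup added above: the format string references custom fields (epoch, train_acc, val_acc) that a standard LogRecord does not carry, so a plain logger.info(...) call will hit a formatting KeyError unless those fields are supplied. A hedged sketch of how the fields would be provided — the file name and values here are illustrative, not from this commit:

import logging

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - Epoch %(epoch)d - Train Acc: %(train_acc).4f - Val Acc: %(val_acc).4f - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S', filename='train.log', filemode='w')
logger = logging.getLogger(__name__)

# Each call must pass the custom fields via `extra`, otherwise the
# formatter raises a KeyError that logging reports on stderr.
logger.info('validation done', extra={'epoch': 1, 'train_acc': 0.9123, 'val_acc': 0.8847})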
@@ -115,6 +119,9 @@ def train_model(model, train_data_loader, test_data_loader, num_epochs, learning
for batch_idx, (inputs, labels) in enumerate(train_data_loader):
optimizer.zero_grad() # Zero the parameter gradients

+            inputs = inputs.to(device)
+            labels = labels.to(device)
+
# Forward pass
outputs = model(inputs)
loss = criterion(outputs, labels)
@@ -138,7 +145,6 @@ def train_model(model, train_data_loader, test_data_loader, num_epochs, learning
logger.info(f'Epoch {epoch + 1}, Validation Accuracy: {accuracy:.2f}%')



if accuracy > best_accuracy:
best_accuracy = accuracy
checkpoint_path = os.path.join(checkpoint_dir, f'model_epoch_{epoch + 1}.pth')
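Taken together, the classification changes mean the caller chooses the device and moves the model onto it; the training loop then moves each batch via .to(device). A minimal usage sketch — the data directory, batch size, num_classes, and the ResNet9 import path are illustrative assumptions, not part of this diff:

import torch
from UniTrain.utils.classification import get_data_loader, train_model
from UniTrain.models.classification import ResNet9  # assumed model location

# Prefer CUDA when present; fall back to CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_loader = get_data_loader('data', batch_size=32, split='train')
test_loader = get_data_loader('data', batch_size=32, split='test')

model = ResNet9(num_classes=10).to(device)  # move weights to the device first
train_model(model, train_loader, test_loader, num_epochs=10,
            learning_rate=0.001, checkpoint_dir='checkpoints', device=device)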
26 changes: 15 additions & 11 deletions UniTrain/utils/segmentation.py
@@ -7,9 +7,11 @@
from torchvision import transforms
from ..dataset.segmentation import SegmentationDataset
import torchsummary
+import logging
+import glob

-def get_data_loader(image_paths:list, mask_paths:list, batch_size:int, shuffle:bool=True, transform=None) -> DataLoader:
-    """
+def get_data_loader(data_dir: str, batch_size:int, shuffle:bool=True, transform=None, split='train') -> DataLoader:
+    """
Create and return a data loader for a custom dataset.
Args:
@@ -28,6 +30,9 @@ def get_data_loader(image_paths:list, mask_paths:list, batch_size:int, shuffle:b
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) # Normalize with ImageNet stats
])

+    image_paths = glob.glob(os.path.join(data_dir, split, 'images', '*'))
+    mask_paths = glob.glob(os.path.join(data_dir, split, 'masks', '*'))

# Create a custom dataset
dataset = SegmentationDataset(image_paths=image_paths, mask_paths=mask_paths, transform=transform)
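
The two glob calls above are what "streamlines" the segmentation dataloading: instead of handing over explicit path lists, callers point at a directory whose splits contain images/ and masks/ subfolders. A minimal sketch of the new call — the directory name and batch size are placeholders:

from UniTrain.utils.segmentation import get_data_loader

# Layout implied by the glob patterns:
#   data/train/images/*   input images
#   data/train/masks/*    corresponding masks
train_loader = get_data_loader('data', batch_size=8, split='train')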

@@ -83,7 +88,7 @@ def parse_folder(dataset_path):
return False


-def train_unet(model, train_data_loader, test_data_loader, num_epochs, learning_rate, checkpoint_dir, logger=None, iou=False, device='cpu'):
+def train_unet(model, train_data_loader, test_data_loader, num_epochs, learning_rate, checkpoint_dir, logger=None, iou=False, device=torch.device('cpu')) -> None:
'''Train the model using the given train and test data loaders.
Args:
@@ -95,17 +100,17 @@ def train_unet(model, train_data_loader, test_data_loader, num_epochs, learning_
checkpoint_dir (str): Directory to save model checkpoints.
logger (Logger): Logger to log training information.
iou (bool): Whether to calculate the IOU score.
+        device (torch.device): Device to run training on (GPU or CPU).
Returns:
None
'''
-    if device == 'cpu':
-        device = torch.device('cpu')
-    elif device == 'cuda':
-        device = torch.device('cuda')
-    else:
-        print(f"{device} is not a valid device.")
-        return None

+    if logger:
+        logging.basicConfig(level=logging.INFO, format='%(asctime)s - Epoch %(epoch)d - Train Acc: %(train_acc).4f - Val Acc: %(val_acc).4f - %(message)s',
+                            datefmt='%Y-%m-%d %H:%M:%S', filename=logger, filemode='w')
+        logger = logging.getLogger(__name__)


criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
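
Dropping the string-validation block is consistent with the new default: callers now construct the device themselves (e.g. torch.device('cuda')), and since PyTorch's .to() accepts either a torch.device or a plain device string, the removed shim added little safety.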
@@ -182,7 +187,6 @@ def iou_score(output, target):
Returns:
float: The average IoU score.
'''

smooth = 1e-6
output = output.argmax(1)
intersection = (output & target).float().sum((1, 2))
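The hunk above cuts off after the intersection. For orientation, a typical completion of such an IoU score under the same conventions (argmax'd predictions, per-image sums over H and W) might look like this — the union and averaging steps are an assumption, not part of the diff:

import torch

def iou_score(output: torch.Tensor, target: torch.Tensor) -> float:
    '''Average intersection-over-union between predictions and targets.'''
    smooth = 1e-6                     # guards against empty unions
    output = output.argmax(1)         # (N, C, H, W) logits -> (N, H, W) labels
    intersection = (output & target).float().sum((1, 2))
    union = (output | target).float().sum((1, 2))   # assumed completion
    iou = (intersection + smooth) / (union + smooth)
    return iou.mean().item()          # mean over the batch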
2 changes: 1 addition & 1 deletion setup.py
@@ -2,7 +2,7 @@

setup(
name='UniTrain',
-    version='0.2.2',
+    version='0.2.3',
author='Ishan Upadhyay',
author_email='[email protected]',
description='A generalized training framework for Deep Learning Tasks',
