Skip to content

Commit

Permalink
Resolved errors related to importing modules
Browse files Browse the repository at this point in the history
  • Loading branch information
ishan121028 committed Oct 7, 2023
1 parent b785d65 commit b40c386
Show file tree
Hide file tree
Showing 15 changed files with 69 additions and 9 deletions.
1 change: 1 addition & 0 deletions UniTrain.egg-info/SOURCES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ LICENSE
README.md
setup.py
UniTrain/__init__.py
UniTrain/train.py
UniTrain.egg-info/PKG-INFO
UniTrain.egg-info/SOURCES.txt
UniTrain.egg-info/dependency_links.txt
Expand Down
1 change: 1 addition & 0 deletions UniTrain.egg-info/requires.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ torch
torchvision
numpy
pandas
torchsummary
3 changes: 3 additions & 0 deletions UniTrain/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .dataset import *
from .models import *
from .utils import *
4 changes: 2 additions & 2 deletions UniTrain/dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from classification import ClassificationDataset
from segmentation import SegmentationDataset
from .classification import ClassificationDataset
from .segmentation import SegmentationDataset
Binary file modified UniTrain/dataset/__pycache__/classification.cpython-311.pyc
Binary file not shown.
4 changes: 2 additions & 2 deletions UniTrain/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from classification import ResNet9
from segmentation import UNet
from .classification import ResNet9
from .segmentation import UNet
Binary file modified UniTrain/models/__pycache__/classification.cpython-311.pyc
Binary file not shown.
38 changes: 38 additions & 0 deletions UniTrain/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from models.segmentation import UNet
from utils.segmentation import parse_folder, get_data_loader, train_unet
from torchvision import transforms
import glob


def main():
if parse_folder('data'):


# Make Your Custom Data Transformations
# transform = transforms.Compose([
# transforms.Resize((224, 224)), # Resize images to a fixed size
# transforms.ToTensor(), # Convert images to PyTorch tensors
# transforms.Normalize((0.485, 0.456, 0.406),
# (0.229, 0.224, 0.225)) # Normalize with ImageNet stats
# ])
train_image_paths = glob.glob("data/train/images/*.jpg")
train_mask_paths = glob.glob("data/train/masks/*.png")

test_image_paths = glob.glob("data/test/images/*.jpg")
test_mask_paths = glob.glob("data/test/masks/*.png")

print(train_image_paths, train_mask_paths, test_image_paths, test_mask_paths)

train_dataloader = get_data_loader(train_image_paths,train_mask_paths, 1, True)
test_dataloader = get_data_loader(test_image_paths, test_mask_paths, 1, True)

model = UNet(n_class=20)

train_unet(model, train_dataloader, test_dataloader, num_epochs=10, learning_rate=1e-3, checkpoint_dir='checkpoints')

else:
print("Invalid dataset folder.")
return None

if __name__ == '__main__':
main()
4 changes: 2 additions & 2 deletions UniTrain/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from classification import get_data_loader, parse_folder, train_model
from segmentation import get_data_loader, parse_folder, train_unet, generate_model_summary, get_iou_score
from .classification import get_data_loader, parse_folder, train_model
from .segmentation import get_data_loader, parse_folder, train_unet, generate_model_summary, iou_score
Binary file modified UniTrain/utils/__pycache__/classification.cpython-311.pyc
Binary file not shown.
Binary file modified UniTrain/utils/__pycache__/segmentation.cpython-311.pyc
Binary file not shown.
20 changes: 18 additions & 2 deletions UniTrain/utils/classification.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from torch.utils.data import DataLoader
from torchvision import transforms
from dataset.classification import ClassificationDataset
from ..dataset.classification import ClassificationDataset
import torch.optim as optim
import torch.nn as nn
import torch
Expand Down Expand Up @@ -86,7 +86,21 @@ def parse_folder(dataset_path):
print("An error occurred:", str(e))
return None

def train_model(model, train_data_loader, test_data_loader, num_epochs, learning_rate=0.001, checkpoint_dir='checkpoints', logger=None):
def train_model(model, train_data_loader, test_data_loader, num_epochs, learning_rate=0.001, checkpoint_dir='checkpoints', logger=None, device='cpu'):
'''Train a PyTorch model for a classification task.
Args:
model (nn.Module): Torch model to train.
train_data_loader (DataLoader): Training data loader.
test_data_loader (DataLoader): Testing data loader.
num_epochs (int): Number of epochs to train the model for.
learning_rate (float): Learning rate for the optimizer.
checkpoint_dir (str): Directory to save model checkpoints.
logger (Logger): Logger to log training details.
Returns:
None
'''

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
Expand Down Expand Up @@ -123,6 +137,8 @@ def train_model(model, train_data_loader, test_data_loader, num_epochs, learning
if logger:
logger.info(f'Epoch {epoch + 1}, Validation Accuracy: {accuracy:.2f}%')



if accuracy > best_accuracy:
best_accuracy = accuracy
checkpoint_path = os.path.join(checkpoint_dir, f'model_epoch_{epoch + 1}.pth')
Expand Down
2 changes: 1 addition & 1 deletion UniTrain/utils/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision import transforms
from dataset.segmentation import SegmentationDataset
from ..dataset.segmentation import SegmentationDataset
import torchsummary

def get_data_loader(image_paths:list, mask_paths:list, batch_size:int, shuffle:bool=True, transform=None) -> DataLoader:
Expand Down
Binary file modified dist/UniTrain-0.1.tar.gz
Binary file not shown.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
'torchvision',
'numpy',
'pandas',
'torchsummary',
],
keywords=['Deep Learning', 'Machine Learning', 'Training Framework'],
classifiers=[
Expand Down

0 comments on commit b40c386

Please sign in to comment.