|
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchsummary import summary
from torchvision import datasets, transforms
| 9 | + |
| 10 | + |
class CIFAR10(nn.Module):
    """CNN classifier for CIFAR-10 images (3x32x32 inputs, 10 classes).

    Architecture: two convolutional blocks (each: two 3x3 convs + ReLU,
    then a 2x2 max-pool), followed by a 3-layer MLP head with dropout.
    ``forward`` returns raw logits (no softmax) — pair with
    ``nn.CrossEntropyLoss``.
    """

    def __init__(self):
        super(CIFAR10, self).__init__()
        # Block 1: 3 -> 64 -> 64 channels; spatial 32 -> 30 -> 28, pooled to 14.
        self.conv1 = nn.Conv2d(3, 64, 3, 1)
        self.conv2 = nn.Conv2d(64, 64, 3, 1)
        # Block 2: 64 -> 128 -> 128 channels; spatial 14 -> 12 -> 10, pooled to 5.
        self.conv3 = nn.Conv2d(64, 128, 3, 1)
        self.conv4 = nn.Conv2d(128, 128, 3, 1)
        # Fix: use nn.Dropout, not nn.Dropout2d. The dropout is applied to the
        # flattened (batch, features) output of fc1, where Dropout2d is wrong
        # (it expects 4-D channel maps and warns/deprecates on 2-D input).
        self.dropout1 = nn.Dropout(0.5)
        # After the second pool the feature map is 128 * 5 * 5 = 3200.
        self.fc1 = nn.Linear(3200, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 10)

    def forward(self, x):
        """Return class logits of shape (batch, 10) for input (batch, 3, 32, 32)."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.dropout1(F.relu(self.fc1(x)))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
| 42 | + |
def fit(model, device, train_loader, val_loader, optimizer, criterion, epochs):
    """Train ``model`` for ``epochs`` epochs, validating after each epoch.

    Args:
        model: the network to optimize (already moved to ``device``).
        device: ``torch.device`` to run batches on.
        train_loader / val_loader: DataLoaders yielding (inputs, labels).
        optimizer: optimizer over ``model.parameters()``.
        criterion: loss taking (logits, labels), e.g. ``nn.CrossEntropyLoss``.
        epochs: number of epochs to run.

    Returns:
        (train_loss, val_loss, train_acc, val_acc): per-epoch Python lists.
    """
    data_loader = {'train': train_loader, 'val': val_loader}
    print("Fitting the model...")
    train_loss, val_loss = [], []
    train_acc, val_acc = [], []
    for epoch in range(epochs):
        loss_per_epoch, val_loss_per_epoch = 0, 0
        acc_per_epoch, val_acc_per_epoch, total, val_total = 0, 0, 0, 0
        for phase in ('train', 'val'):
            # Fix: toggle train/eval so dropout is disabled during validation,
            # and disable autograd for the val pass — the original computed
            # gradients (and kept dropout active) on validation batches.
            if phase == 'train':
                model.train()
            else:
                model.eval()
            with torch.set_grad_enabled(phase == 'train'):
                for data in data_loader[phase]:
                    inputs, labels = data[0].to(device), data[1].to(device)
                    outputs = model(inputs)
                    # Predicted class per sample of this batch.
                    preds = torch.max(outputs, 1)[1]
                    loss = criterion(outputs, labels)
                    if phase == 'train':
                        acc_per_epoch += (labels == preds).sum().item()
                        total += labels.size(0)
                        optimizer.zero_grad()
                        # Gradients w.r.t. the loss, then weight update.
                        loss.backward()
                        optimizer.step()
                        loss_per_epoch += loss.item()
                    else:
                        val_acc_per_epoch += (labels == preds).sum().item()
                        val_total += labels.size(0)
                        val_loss_per_epoch += loss.item()
        print("Epoch: {} Loss: {:0.6f} Acc: {:0.6f} Val_Loss: {:0.6f} Val_Acc: {:0.6f}".format(epoch+1,loss_per_epoch/len(train_loader),acc_per_epoch/total,val_loss_per_epoch/len(val_loader),val_acc_per_epoch/val_total))
        train_loss.append(loss_per_epoch / len(train_loader))
        val_loss.append(val_loss_per_epoch / len(val_loader))
        train_acc.append(acc_per_epoch / total)
        val_acc.append(val_acc_per_epoch / val_total)
    return train_loss, val_loss, train_acc, val_acc
| 78 | + |
if __name__ == '__main__':
    # Seed numpy and torch for reproducible splits and weight init.
    np.random.seed(42)
    torch.manual_seed(42)

    # Fix: single source of truth for the epoch count — the original
    # hard-coded 50 in fit() and np.arange(1, 51) in four plot calls.
    EPOCHS = 50

    # Per-channel shift by 0.5 (std left at 1.0), i.e. pixels in [-0.5, 0.5].
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (1.0, 1.0, 1.0))])
    dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
    # 45k/5k train/validation split of the 50k CIFAR-10 training images.
    train_set, val_set = torch.utils.data.random_split(dataset, [45000, 5000])
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=128, shuffle=True)

    use_cuda = True
    device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu")

    model = CIFAR10().to(device)
    summary(model, (3, 32, 32))

    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9,
                          nesterov=True, weight_decay=1e-6)
    criterion = nn.CrossEntropyLoss()

    train_loss, val_loss, train_acc, val_acc = fit(model, device, train_loader,
                                                   val_loader, optimizer, criterion, EPOCHS)

    epoch_axis = np.arange(1, EPOCHS + 1)

    fig = plt.figure(figsize=(5, 5))
    plt.plot(epoch_axis, train_loss, "*-", label="Training Loss")
    plt.plot(epoch_axis, val_loss, "o-", label="Val Loss")
    plt.xlabel("Num of epochs")
    plt.legend()
    plt.savefig('cifar10_model_loss_event.png')

    fig = plt.figure(figsize=(5, 5))
    plt.plot(epoch_axis, train_acc, "*-", label="Training Acc")
    plt.plot(epoch_axis, val_acc, "o-", label="Val Acc")
    plt.xlabel("Num of epochs")
    plt.legend()
    plt.savefig('cifar10_model_accuracy_event.png')

    # Fix: create the checkpoint directory first — torch.save raises
    # FileNotFoundError if ./models does not already exist.
    os.makedirs('./models', exist_ok=True)
    torch.save(model.state_dict(), './models/cifar10_model.pt')