@@ -91,8 +91,10 @@ def main():
9191 help = 'learning rate (default: 0.1)' )
9292 parser .add_argument ('--gamma' , type = float , default = 0.7 , metavar = 'M' ,
9393 help = 'learning rate step gamma (default: 0.7)' )
94- parser .add_argument ('--no-cuda' , action = 'store_true' , default = False ,
95- help = 'disables CUDA training' )
94+ parser .add_argument ('--cuda' , action = 'store_true' , default = False ,
95+ help = 'enables CUDA training' )
96+ parser .add_argument ('--mps' , action = "store_true" , default = False ,
97+ help = "enables MPS training" )
9698 parser .add_argument ('--dry-run' , action = 'store_true' , default = False ,
9799 help = 'quickly check a single pass' )
98100 parser .add_argument ('--seed' , type = int , default = 1 , metavar = 'S' ,
@@ -102,13 +104,19 @@ def main():
102104 parser .add_argument ('--save-model' , action = 'store_true' , default = False ,
103105 help = 'for Saving the current Model' )
104106 args = parser .parse_args ()
105- use_cuda = not args .no_cuda and torch .cuda .is_available ()
106107
107- torch .manual_seed (args .seed )
108+ if args .cuda and not args .mps :
109+ device = "cuda"
110+ elif args .mps and not args .cuda :
111+ device = "mps"
112+ else :
113+ device = "cpu"
114+
115+ device = torch .device (device )
108116
109- device = torch .device ( "cuda" if use_cuda else "cpu" )
117+ torch .manual_seed ( args . seed )
110118
111- kwargs = {'num_workers' : 1 , 'pin_memory' : True } if use_cuda else {}
119+ kwargs = {'num_workers' : 1 , 'pin_memory' : True } if args . cuda else {}
112120 train_loader = torch .utils .data .DataLoader (
113121 datasets .MNIST ('../data' , train = True , download = True ,
114122 transform = transforms .Compose ([
0 commit comments