From b6a9505e6540f725cb79037282cf3248ea09ed0e Mon Sep 17 00:00:00 2001 From: MasterSkepticista Date: Thu, 16 Jan 2025 15:01:45 +0530 Subject: [PATCH] Download mnist globally under /tmp Signed-off-by: MasterSkepticista --- openfl-workspace/torch_cnn_mnist/src/dataloader.py | 2 +- openfl-workspace/torch_cnn_mnist_fed_eval/src/mnist_utils.py | 2 +- .../torch_cnn_mnist_straggler_check/src/mnist_utils.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/openfl-workspace/torch_cnn_mnist/src/dataloader.py b/openfl-workspace/torch_cnn_mnist/src/dataloader.py index 3f3eeeb0bb..4512b521a8 100644 --- a/openfl-workspace/torch_cnn_mnist/src/dataloader.py +++ b/openfl-workspace/torch_cnn_mnist/src/dataloader.py @@ -121,7 +121,7 @@ def _load_raw_datashards(shard_num, collaborator_count, transform=None): 2 tuples: (image, label) of the training, validation dataset """ train_data, val_data = ( - datasets.MNIST("data", train=train, download=True, transform=transform) + datasets.MNIST("/tmp/mnist", train=train, download=True, transform=transform) for train in (True, False) ) X_train_tot, y_train_tot = train_data.train_data, train_data.train_labels diff --git a/openfl-workspace/torch_cnn_mnist_fed_eval/src/mnist_utils.py b/openfl-workspace/torch_cnn_mnist_fed_eval/src/mnist_utils.py index 95fa35fa6f..ef7a33a439 100644 --- a/openfl-workspace/torch_cnn_mnist_fed_eval/src/mnist_utils.py +++ b/openfl-workspace/torch_cnn_mnist_fed_eval/src/mnist_utils.py @@ -41,7 +41,7 @@ def _load_raw_datashards(shard_num, collaborator_count, transform=None): 2 tuples: (image, label) of the training, validation dataset """ train_data, val_data = ( - datasets.MNIST('data', train=train, download=True, transform=transform) + datasets.MNIST('/tmp/mnist', train=train, download=True, transform=transform) for train in (True, False) ) X_train_tot, y_train_tot = train_data.train_data, train_data.train_labels diff --git a/openfl-workspace/torch_cnn_mnist_straggler_check/src/mnist_utils.py b/openfl-workspace/torch_cnn_mnist_straggler_check/src/mnist_utils.py index a03e1e6da2..24ed5bc794 100644 --- a/openfl-workspace/torch_cnn_mnist_straggler_check/src/mnist_utils.py +++ b/openfl-workspace/torch_cnn_mnist_straggler_check/src/mnist_utils.py @@ -41,7 +41,7 @@ def _load_raw_datashards(shard_num, collaborator_count, transform=None): 2 tuples: (image, label) of the training, validation dataset """ train_data, val_data = ( - datasets.MNIST('data', train=train, download=True, transform=transform) + datasets.MNIST('/tmp/mnist', train=train, download=True, transform=transform) for train in (True, False) ) X_train_tot, y_train_tot = train_data.train_data, train_data.train_labels