diff --git a/src/flash/core/data/utilities/loading.py b/src/flash/core/data/utilities/loading.py index 0075309c8d..cc5e5067db 100644 --- a/src/flash/core/data/utilities/loading.py +++ b/src/flash/core/data/utilities/loading.py @@ -72,7 +72,13 @@ def _load_image_from_image(file): def _load_image_from_numpy(file): - return Image.fromarray(np.load(file).astype("uint8")).convert("RGB") + arr = np.load(file) + if not (arr == arr.astype("uint8")).all(): + # Max pixel value -> 255, min -> 0 + low = arr.min() + arr = 255 * (arr - low) / (arr.max() - low) + + return Image.fromarray(arr.astype("uint8")).convert("RGB") def _load_spectrogram_from_image(file):