Adding Stacked_MNIST #230 #231

Open
wants to merge 15 commits into
base: master
2 changes: 2 additions & 0 deletions Project.toml
@@ -5,6 +5,7 @@ version = "0.7.14"
[deps]
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
Chemfiles = "46823bd8-5fb3-5f92-9aa0-96921f3dd015"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
@@ -22,6 +23,7 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
Pickle = "fbb45041-c46e-462f-888f-7c521cafbc2c"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1 change: 1 addition & 0 deletions docs/src/datasets/vision.md
@@ -25,6 +25,7 @@ CIFAR100
EMNIST
FashionMNIST
MNIST
StackedMNIST
Omniglot
SVHN2
```
5 changes: 5 additions & 0 deletions src/MLDatasets.jl
@@ -13,6 +13,8 @@ using FileIO
import CSV
using LazyModules: @lazy
using Statistics
using Random
using Colors

include("require.jl") # export @require

@@ -90,6 +92,8 @@ export FashionMNIST
include("datasets/vision/mnist_reader/MNISTReader.jl")
include("datasets/vision/mnist.jl")
export MNIST
include("datasets/vision/stacked_mnist.jl")
export StackedMNIST
include("datasets/vision/omniglot.jl")
export Omniglot
include("datasets/vision/svhn2.jl")
@@ -175,6 +179,7 @@ function __init__()
__init__emnist()
__init__fashionmnist()
__init__mnist()
__init__stackedmist()
__init__omniglot()
__init__svhn2()

160 changes: 160 additions & 0 deletions src/datasets/vision/stacked_mnist.jl
@@ -0,0 +1,160 @@
function __init__stackedmist()
DEPNAME = "StackedMNIST"
TRAINIMAGES = "train-images-idx3-ubyte.gz"
TRAINLABELS = "train-labels-idx1-ubyte.gz"
TESTIMAGES = "t10k-images-idx3-ubyte.gz"
TESTLABELS = "t10k-labels-idx1-ubyte.gz"
register(DataDep(DEPNAME,
"""Dataset: The Stacked MNIST dataset is derived from the standard MNIST dataset with an increased number of discrete modes. 240,000 RGB images of size 28×28 are synthesized by stacking three random digit images from MNIST along the color channel, resulting in 1,000 explicit modes in a uniform distribution, one for each possible triple of digits.
Authors: Metz et al.
Website: https://paperswithcode.com/dataset/stacked-mnist

[Metz L et al., 2016]
Metz L, Poole B, Pfau D, Sohl-Dickstein J. Unrolled generative adversarial networks. arXiv preprint arXiv:1611.02163. 2016 Nov 7.
""",
"",
[TRAINIMAGES, TRAINLABELS, TESTIMAGES, TESTLABELS]
))
end

"""
    StackedMNIST(; Tx=Float32, split=:train, size=60000, dir=nothing)
    StackedMNIST([Tx, split]; size=60000, dir=nothing)

The StackedMNIST dataset is a variant of the classic MNIST dataset where each observation is a combination of three randomly shuffled MNIST digits, stacked as RGB channels.

# Arguments

- `Tx`: The data type for the features. Defaults to `Float32`. If `Tx <: Integer`, the features will range between 0 and 255; otherwise, they will be scaled between 0 and 1.
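- `split`: The MNIST partition to draw digits from, either `:train` or `:test`. Defaults to `:train`.
- `size`: The number of stacked images to generate, independent of `split`. Defaults to `60000`.
- `dir`: The directory where the dataset is stored. If `nothing`, the default location is used.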

# Fields

- `features`: A 4D array of stacked MNIST images with dimensions `(28, 28, 3, num_images)`, where `num_images` equals `size`.
- `split`: The MNIST partition the digits were drawn from, either `:train` or `:test`.
- `targets`: A vector of tuples, each containing three integers representing the combined labels for the stacked RGB image.
- `size`: The total number of images in the dataset.

# Methods

- `convert2image`: Converts feature arrays to RGB images.
- `Base.length(sm::StackedMNIST)`: Returns the number of images in the dataset.
- `Base.getindex(sm::StackedMNIST, idx::Int)`: Returns a named tuple `(features, targets)` holding the `28×28×3` feature array and the label tuple at the specified index.

# Examples

The images in the StackedMNIST dataset are loaded as a multi-dimensional array of type `Tx`. The dataset's `features` field is a 4D array in WHCN format (width, height, channels, num_images). Labels are stored as a vector of tuples in `StackedMNIST().targets`. The images are constructed by stacking three randomly chosen MNIST digits as RGB channels, resulting in 1,000 explicit modes corresponding to the number of possible triples of digits.

```julia-repl
julia> using MLDatasets: StackedMNIST

julia> dataset = StackedMNIST(:train)
StackedMNIST:
features => 28×28×3×60000 Array{Float32, 4}
targets => 60000-element Vector{Tuple{Int, Int, Int}}

julia> dataset[1:5].targets
5-element Vector{Tuple{Int, Int, Int}}:
(7, 2, 1)
(2, 3, 8)
(1, 5, 3)
(4, 0, 9)
(7, 4, 5)

julia> img, label = dataset[1];

julia> size(img), label
((28, 28, 3), (7, 2, 1))

julia> dataset = StackedMNIST(UInt8, :test; size = 10000)
StackedMNIST:
features => 28×28×3×10000 Array{UInt8, 4}
split => :test
targets => 10000-element Vector{Tuple{Int, Int, Int}}
```
"""
struct StackedMNIST <: SupervisedDataset
features::Any
split::Symbol
targets::Vector{Tuple{Int, Int, Int}}
size::Int
end

# Convenience constructors for StackedMNIST
function StackedMNIST(; split = :train, Tx = Float32, size = 60000, dir = nothing)
StackedMNIST(Tx, split; size, dir)
end
StackedMNIST(split::Symbol; kws...) = StackedMNIST(; split, kws...)
StackedMNIST(Tx::Type; kws...) = StackedMNIST(; Tx, kws...)
function StackedMNIST(size::Integer; split = :train, Tx = Float32, dir = nothing)
StackedMNIST(Tx, split; size = size, dir = dir)
end

function StackedMNIST(
        Tx::Type,
        split::Symbol = :train;
        size = 60000, dir = nothing)
    mnist = MNIST(Tx, split; dir = dir)
    split = mnist.split

    mnist_targets = vec(mnist.targets)
    targets = Vector{Tuple{Int, Int, Int}}(undef, size)
    features = Array{Tx, 4}(undef, 28, 28, 3, size)

    # Index the MNIST images by digit once, instead of scanning every label
    # with `findall` three times per generated image.
    indices_by_label = Dict(d => findall(==(d), mnist_targets)
                            for d in unique(mnist_targets))

    # Return the labels found at three distinct random positions of `vec`.
    function random_three_unique(vec)
        indices = randperm(length(vec))[1:3]
        return (vec[indices[1]], vec[indices[2]], vec[indices[3]])
    end

    # Generate `size` stacked images.
    for i in 1:size
        labels = random_three_unique(mnist_targets)
        targets[i] = labels
        # Fill the red, green, and blue channels with a random MNIST image
        # of the corresponding digit.
        for (channel, label) in enumerate(labels)
            random_index = rand(indices_by_label[label])
            features[:, :, channel, i] = mnist.features[:, :, random_index]
        end
    end

    StackedMNIST(features, split, targets, size)
end

# Define the length function
Base.length(sm::StackedMNIST) = sm.size

# Define the getindex function
function Base.getindex(sm::StackedMNIST, idx::Int)
return (features = sm.features[:, :, :, idx], targets = sm.targets[idx])
end

# Function to extract and show an RGB image
function show_rgb_image(features, index)
    red_channel = features[:, :, 1, index]   # Extract the red channel
    green_channel = features[:, :, 2, index] # Extract the green channel
    blue_channel = features[:, :, 3, index]  # Extract the blue channel

    # Combine the channels element-wise into a matrix of RGB pixels
    img_rgb = Colors.RGB.(red_channel, green_channel, blue_channel)
    return img_rgb # Return the image; displaying it is left to the caller
end

function convert2image(::Type{<:StackedMNIST}, x::AbstractArray{<:Integer})
# Reinterpret the input array as N0f8 and convert it to StackedMNIST-compatible format
return convert2image(StackedMNIST, reinterpret(N0f8, convert(Array{UInt8}, x)))
end

function convert2image(::Type{<:StackedMNIST}, x::AbstractArray{T, N}) where {T, N}
@assert N == 3 || N == 4
x = permutedims(x, (2, 1, 3:N...))
img_rgb = Colors.RGB{T}.(x[:, :, 1, :], x[:, :, 2, :], x[:, :, 3, :])
return reshape(img_rgb, size(img_rgb, 1), size(img_rgb, 2), size(img_rgb, 3))
end
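
The stacking logic in `src/datasets/vision/stacked_mnist.jl` can be tried without downloading MNIST by running it on synthetic data. The following is a minimal sketch under stated assumptions, not the PR's implementation: `stack_digits`, `fake_images`, and `fake_labels` are hypothetical names, and the random arrays merely stand in for `mnist.features` and `mnist.targets`.

```julia
using Random

# Hypothetical helper (not part of MLDatasets): build `n` stacked images from
# `images` (a W×H×N array) and `labels` (a length-N vector of digits).
function stack_digits(rng::AbstractRNG, images::AbstractArray{T, 3},
                      labels::AbstractVector{<:Integer}, n::Integer) where {T}
    w, h, _ = size(images)
    features = Array{T, 4}(undef, w, h, 3, n)
    targets = Vector{NTuple{3, Int}}(undef, n)

    # Group image indices by digit once, so each lookup below is cheap.
    by_label = Dict(d => findall(==(d), labels) for d in unique(labels))

    for i in 1:n
        # Labels at three distinct random positions, as in the PR.
        picks = randperm(rng, length(labels))[1:3]
        digits = (labels[picks[1]], labels[picks[2]], labels[picks[3]])
        targets[i] = digits
        for (channel, digit) in enumerate(digits)
            # Any image showing `digit` may fill this color channel.
            features[:, :, channel, i] = images[:, :, rand(rng, by_label[digit])]
        end
    end
    return features, targets
end

# Synthetic stand-ins for MNIST: 50 random 28×28 "images" with digit labels.
rng = MersenneTwister(0)
fake_labels = rand(rng, 0:9, 50)
fake_images = rand(rng, Float32, 28, 28, 50)

features, targets = stack_digits(rng, fake_images, fake_labels, 8)
```

Passing an explicit `rng` is a design choice of this sketch that makes the generated dataset reproducible; the constructor in the PR draws from the global random number generator instead.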
14 changes: 9 additions & 5 deletions test/runtests.jl
@@ -19,6 +19,7 @@ dataset_tests = [
"datasets/text.jl",
"datasets/vision/fashion_mnist.jl",
"datasets/vision/mnist.jl",
"datasets/vision/stacked_mnist.jl"
]

no_ci_dataset_tests = [
@@ -29,7 +30,7 @@ no_ci_dataset_tests = [
"datasets/vision/emnist.jl",
"datasets/vision/omniglot.jl",
"datasets/vision/svhn2.jl",
-"datasets/meshes.jl",
+"datasets/meshes.jl"
]

@assert isempty(intersect(dataset_tests, no_ci_dataset_tests))
@@ -39,11 +40,12 @@ container_tests = [
# "containers/tabledataset.jl",
# "containers/hdf5dataset.jl",
# "containers/jld2dataset.jl",
-"containers/cacheddataset.jl",
+"containers/cacheddataset.jl"
]

@testset "Datasets" begin
@testset "$(split(t,"/")[end])" for t in dataset_tests
@info "Including $t"
include(t)
end

@@ -57,8 +59,10 @@ container_tests = [
end
end

-@testset "Containers" begin for t in container_tests
-include(t)
-end end
+@testset "Containers" begin
+for t in container_tests
+include(t)
+end
+end

nothing
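
The dataset description registered in `__init__stackedmist` counts 1,000 modes, one per ordered triple of digits. When checking how many of those modes a set of targets covers, each tuple can be mapped to an integer in `0:999`. The snippet below is an illustrative sketch only; `mode_index` and `modes_covered` are hypothetical helpers, not functions provided by this PR.

```julia
# Hypothetical helper: read the triple (r, g, b) as the three-digit number "rgb".
mode_index(t::NTuple{3, Integer}) = 100 * t[1] + 10 * t[2] + t[3]

# Number of distinct modes present in a vector of target tuples.
modes_covered(targets) = length(unique(mode_index.(targets)))

example_targets = [(7, 2, 1), (2, 3, 8), (7, 2, 1), (0, 0, 9)]
mode_index((7, 2, 1))           # 721
modes_covered(example_targets)  # 3, because (7, 2, 1) appears twice
```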