Skip to content

Commit

Permalink
Merge pull request #83 from ShuhuaGao/MNIST-display
Browse files Browse the repository at this point in the history
render black digits on a whiteboard
  • Loading branch information
CarloLucibello authored Nov 7, 2021
2 parents 3e15931 + 055218e commit d184e5c
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 6 deletions.
18 changes: 12 additions & 6 deletions src/MNIST/utils.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""
convert2image(array) -> Array{Gray}
convert2image(array; black_digits=false) -> Array{Gray}
Convert the given MNIST horizontal-major tensor (or feature matrix)
to a vertical-major `Colorant` array. The values are also color
to a vertical-major `Colorant` array. If `black_digits` is `true`, the values are also color
corrected according to the website's description, which means that
the digits are black on a white background.
Expand All @@ -16,17 +16,23 @@ julia> MNIST.convert2image(MNIST.traintensor(1)) # first training image
[...]
```
"""
function convert2image(array::AbstractArray{T}) where {T<:Number}
function convert2image(array::AbstractArray{T}; black_digits::Bool=false) where {T<:Number}
nlast = size(array)[end]
array = reshape(array, 28, 28, :)
array = permutedims(array, (2, 1, 3))
if size(array)[end] == 1 && nlast != 1
array = dropdims(array, dims=3)
end
if any(x -> x > 1, array) # simple check if x in [0,1]
img = _colorview(Gray, array ./ T(255))
array = array ./ T(255) # avoid changing the input array
if black_digits
array .= one(eltype(array)) .- array
end
else
img = _colorview(Gray, array)
if black_digits
array = one(eltype(array)) .- array
end
end
img

return _colorview(Gray, array)
end
11 changes: 11 additions & 0 deletions test/tst_mnist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@ end
@test size(A) == (28,28,2)
@test eltype(A) == Gray{N0f8}
@test MNIST.convert2image(vec(data)) == A

# test black digits and white background
data = rand(N0f8,28,28,2)
data[1] = 0
data[3, 3, 2] = 0.4
A = MNIST.convert2image(data; black_digits=true)
@test A[1] == 1
@test A[3, 3, 2] == 0.6
@test size(A) == (28,28,2)
@test eltype(A) == Gray{N0f8}
@test MNIST.convert2image(vec(data); black_digits=true) == A
end

# NOT executed on CI. only executed locally.
Expand Down

0 comments on commit d184e5c

Please sign in to comment.