Skip to content

Commit

Permalink
WIP: Mmappable Arrays (#582)
Browse files Browse the repository at this point in the history
* wip: mmappable arrays

* tests

* include mmap_test

* disable broken mmap on windows

* update warning test
  • Loading branch information
JonasIsensee authored Aug 23, 2024
1 parent 7f792e8 commit cc86e5d
Show file tree
Hide file tree
Showing 7 changed files with 250 additions and 23 deletions.
3 changes: 0 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,16 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Mmap = "a63ad114-7e13-5084-954f-fe012c677804"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
TranscodingStreams = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[compat]
FileIO = "1"
MacroTools = "0.5"
Mmap = "1"
OrderedCollections = "1"
PrecompileTools = "1"
Reexport = "1"
Requires = "1"
TranscodingStreams = "0.9, 0.10, 0.11"
UUIDs = "1"
Expand Down
5 changes: 2 additions & 3 deletions src/JLD2.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
module JLD2
using OrderedCollections: OrderedDict
using Reexport: @reexport
using MacroTools: MacroTools, @capture
using Mmap: Mmap
using Unicode: Unicode
using TranscodingStreams: TranscodingStreams
@reexport using FileIO: load, save
using FileIO: load, save
export load, save
using Requires: @require
using PrecompileTools: @setup_workload, @compile_workload

Expand Down
8 changes: 8 additions & 0 deletions src/dataio.jl
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,14 @@ function write_data(io::IOStream, f::JLDFile, data::Array{T}, odr::Type{T}, ::Re
nothing
end

"""
    write_data(io::IOStream, f::JLDFile, data, odr, _, wsession::JLDWriteSession)

Convert `data` to its on-disk representation `odr` and write it to `io`.

The value is converted with `h5convert!` into a temporary buffer of `odr_sizeof(odr)`
bytes, which is then written to `io` in a single call. The fifth positional argument
is ignored by this method. Returns `nothing`.
"""
function write_data(io::IOStream, f::JLDFile, data, odr, _, wsession::JLDWriteSession)
    nbytes = odr_sizeof(odr)
    buf = Vector{UInt8}(undef, nbytes)
    # Raw pointers do not root `buf`; keep it alive for as long as they are in use.
    GC.@preserve buf begin
        h5convert!(Ptr{Cvoid}(pointer(buf)), odr, f, data, wsession)
        unsafe_write(io, pointer(buf), nbytes)
    end
    return nothing
end

function write_data(io::BufferedWriter, f::JLDFile, data::Array{T}, odr::S,
::DataMode, wsession::JLDWriteSession) where {T,S}
position = io.position[]
Expand Down
11 changes: 6 additions & 5 deletions src/datasets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ end
# Without attributes, the dimensionality and the dimensions offset are taken
# directly from the dataspace.
function get_ndims_offset(f::JLDFile, dataspace::ReadDataspace, attributes::Nothing)
    return (dataspace.dimensionality, dataspace.dimensions_offset)
end

function get_ndims_offset(f::JLDFile, dataspace::ReadDataspace, attributes::Vector{ReadAttribute})
function get_ndims_offset(f::JLDFile, dataspace::ReadDataspace, attributes::AbstractVector)
ndims = dataspace.dimensionality
offset = dataspace.dimensions_offset
if !isempty(attributes)
Expand Down Expand Up @@ -363,8 +363,7 @@ end
psz += CONTINUATION_MSG_SIZE

# Figure out the layout
# The simplest CompactStorageMessage only supports data sets < 2^16
if datasz < 8192 || (!(data isa Array) && datasz < typemax(UInt16))
if datasz == 0 || (!(data isa Array) && datasz < 8192)
layout_class = LcCompact
psz += jlsizeof(CompactStorageMessage) + datasz
elseif data isa Array && compress != false && isconcretetype(eltype(data)) && isbitstype(eltype(data))
Expand Down Expand Up @@ -420,11 +419,13 @@ end
f.end_of_data += length(deflated)
jlwrite(f.io, deflated)
else
jlwrite(cio, ContiguousStorageMessage(datasz, h5offset(f, f.end_of_data)))
data_address = f.end_of_data + 8 - mod1(f.end_of_data, 8)
jlwrite(cio, ContiguousStorageMessage(datasz, h5offset(f, data_address)))
jlwrite(cio, CONTINUATION_PLACEHOLDER)
jlwrite(io, end_checksum(cio))

f.end_of_data += datasz
f.end_of_data = data_address + datasz
seek(io, data_address)
write_data(io, f, data, odr, datamode(odr), wsession)
end

Expand Down
155 changes: 144 additions & 11 deletions src/explicit_datasets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,35 @@ mutable struct Dataset
header_chunk_info # chunk_start, chunk_end, next_msg_offset
end


"""
create_dataset(parent, path, datatype, dataspace; kwargs...)
Arguments:
- `parent::Union{JLDFile, Group}`: Containing group of new dataset
- `path`: Path to new dataset relative to `parent`. If `path` is `nothing`, the dataset is unnamed.
- `datatype`: Datatype of new dataset (element type in case of arrays)
- `dataspace`: Dimensions or `Dataspace` of new dataset
Keyword arguments:
- `layout`: `DataLayout` of new dataset
- `filters`: `FilterPipeline` for describing the compression pipeline
"""
function create_dataset(f::JLDFile, args...; kwargs...)
    return create_dataset(f.root_group, args...; kwargs...)
end
function create_dataset(
parent::Group,
name::Union{Nothing,String},
g::Group,
path::Union{Nothing,String},
datatype=nothing,
dataspace=nothing;
layout = nothing,
chunk=nothing,
filters=Filter[],
filters=FilterPipeline(),
)
if !isnothing(name)
(parent, name) = pathize(parent, name, true)
if !isnothing(path)
(parent, name) = pathize(g, path, true)
else
name = ""
parent = g.f

Check warning on line 43 in src/explicit_datasets.jl

View check run for this annotation

Codecov / codecov/patch

src/explicit_datasets.jl#L42-L43

Added lines #L42 - L43 were not covered by tests
end

return Dataset(parent, name, UNDEFINED_ADDRESS, datatype, dataspace,
Expand Down Expand Up @@ -119,6 +136,7 @@ function write_dataset(dataset::Dataset, data)
throw(ArgumentError("Invalid attribute: $a"))
end
io = f.io
odr = objodr(data)
datasz = odr_sizeof(odr)::Int * numel(dataspace)::Int

psz = payload_size_without_storage_message(dataspace, datatype)::Int
Expand All @@ -131,11 +149,11 @@ function write_dataset(dataset::Dataset, data)

# determine layout class
# DataLayout object is only available after the data is written
if datasz < 8192
if datasz == 0 || (!(data isa Array) && datasz < 8192)
layout_class = LcCompact
psz += jlsizeof(CompactStorageMessage) + datasz

elseif !isnothing(dataset.chunk) || !isempty(dataset.filters)
elseif !isnothing(dataset.chunk) || !isempty(dataset.filters.filters)

Check warning on line 156 in src/explicit_datasets.jl

View check run for this annotation

Codecov / codecov/patch

src/explicit_datasets.jl#L156

Added line #L156 was not covered by tests
# Do some additional checks on the data here
layout_class = LcChunked
# improve filter support here
Expand All @@ -144,7 +162,7 @@ function write_dataset(dataset::Dataset, data)
layout_class = LcContiguous
psz += jlsizeof(ContiguousStorageMessage)
end
fullsz = jlsizeof(ObjectStart) + size_size(psz) + psz + 4 # why do I need to correct here?
fullsz = jlsizeof(ObjectStart) + size_size(psz) + psz + 4

header_offset = f.end_of_data
seek(io, header_offset)
Expand Down Expand Up @@ -191,14 +209,18 @@ function write_dataset(dataset::Dataset, data)
jlwrite(f.io, end_checksum(cio))

else
jlwrite(cio, ContiguousStorageMessage(datasz, h5offset(f, f.end_of_data)))
# Align contiguous chunk to 8 bytes in the file
address = f.end_of_data + 8 - mod1(f.end_of_data, 8)
offset = h5offset(f, address)
jlwrite(cio, ContiguousStorageMessage(datasz, offset))

Check warning on line 215 in src/explicit_datasets.jl

View check run for this annotation

Codecov / codecov/patch

src/explicit_datasets.jl#L213-L215

Added lines #L213 - L215 were not covered by tests

dataset.header_chunk_info = (header_offset, position(cio)+20, position(cio))
# Add NIL message replaceable by continuation message
jlwrite(io, CONTINUATION_PLACEHOLDER)
jlwrite(io, end_checksum(cio))

f.end_of_data += datasz
f.end_of_data = address + datasz
seek(io, address)

Check warning on line 223 in src/explicit_datasets.jl

View check run for this annotation

Codecov / codecov/patch

src/explicit_datasets.jl#L222-L223

Added lines #L222 - L223 were not covered by tests
write_data(io, f, data, odr, datamode(odr), wsession)
end

Expand Down Expand Up @@ -243,7 +265,7 @@ function get_dataset(f::JLDFile, offset::RelOffset, g=f.root_group, name="")
hmitr = HeaderMessageIterator(f, offset)
for msg in hmitr
if msg.type == HmDataspace
dset.dataspace = HmWrap(HmDataspace, msg)#ReadDataspace(f, msg)
dset.dataspace = HmWrap(HmDataspace, msg)
elseif msg.type == HmDatatype
dset.datatype = HmWrap(HmDatatype, msg).dt
elseif msg.type == HmDataLayout
Expand Down Expand Up @@ -411,4 +433,115 @@ function attributes(dset::Dataset; plain::Bool=false)
OrderedDict(keys(dset.attributes) .=> map(values(dset.attributes)) do attr
read_attr_data(dset.parent.f, attr)
end)
end

## Mmap Arrays
"""
    ismmappable(dset::Dataset)

Return `true` if the data of `dset` can be memory-mapped with `readmmap`.

All of the following must hold:
- `dset` has been written to the file (`iswritten(dset)`),
- `samelayout(T)` is true for the element type `T` derived from the stored datatype,
- the dataset has no filters,
- the stored layout is contiguous and has a defined data address.

On Windows, a dataset in a writable file is reported as not mappable and a warning
is logged.
"""
function ismmappable(dset::Dataset)
    iswritten(dset) || return false
    file = dset.parent.f

    # Resolve shared datatypes through the file's datatype table before
    # asking for the corresponding Julia type.
    dtype = dset.datatype
    resolved = if dtype isa SharedDatatype
        get(file.datatype_locations, dtype.header_offset, dtype)
    else
        dtype
    end
    rr = jltype(file, resolved)
    T = typeof(rr).parameters[1]

    samelayout(T) || return false
    isempty(dset.filters.filters) || return false

    layout = dset.layout
    layout isa HmWrap{HmDataLayout} || return false
    contiguous = layout.layout_class == LcContiguous &&
        layout.data_address != UNDEFINED_ADDRESS
    contiguous || return false

    if Sys.iswindows() && file.writable
        @warn "On Windows memory-mapping is only possible for files in read-only mode."
        return false
    end
    return true
end

"""
    readmmap(dset::Dataset)

Return the data of `dset` as an array created by `Mmap.mmap` on the file's stream.

The element type is derived from the stored datatype and the dimensions are read
from the file. Throws an `ArgumentError` if `ismmappable(dset)` is `false`.
"""
function readmmap(dset::Dataset)
    if !ismmappable(dset)
        throw(ArgumentError("Dataset is not mmappable"))
    end
    file = dset.parent.f

    # Element type: resolve shared datatypes through the file's datatype table first.
    dtype = dset.datatype
    resolved = if dtype isa SharedDatatype
        get(file.datatype_locations, dtype.header_offset, dtype)
    else
        dtype
    end
    rr = jltype(file, resolved)
    T = typeof(rr).parameters[1]

    attrs = collect(values(dset.attributes))
    rank, dims_offset = get_ndims_offset(file, ReadDataspace(file, dset.dataspace), attrs)

    # The dimensions are read as consecutive Int64 values starting at `dims_offset`.
    io = file.io
    seek(io, dims_offset)
    stored_dims = [jlread(io, Int64) for _ in 1:rank]

    # Map through an IOStream; for any other io type its `f` field is used instead.
    stream = io isa IOStream ? io : io.f
    seek(stream, DataLayout(file, dset.layout).data_offset)
    # Stored dimensions are reversed to obtain the Julia array size.
    return Mmap.mmap(stream, Array{T, Int(rank)}, Tuple(reverse(stored_dims)))
end

@static if !Sys.iswindows()
    # allocate_early(dset::Dataset, T::DataType)
    #
    # Write the object header of `dset` (dataspace, datatype, attributes and a
    # contiguous storage message) to the file and reserve an 8-byte aligned region of
    # `odr_sizeof * numel(dataspace)` bytes for the data, without writing any data.
    # `T` is the Julia type used to look up the on-disk representation of one element.
    #
    # Afterwards all fields of `dset` are refreshed from the header that was just
    # written. Returns the offset of the new object header.
    #
    # Throws an `ArgumentError` if `dset` was already written, or if its datatype or
    # dataspace is missing. Only defined on non-Windows systems.
    function allocate_early(dset::Dataset, T::DataType)
        iswritten(dset) && throw(ArgumentError("Dataset has already been written to file"))
        # for this to work, require all information to be provided
        isnothing(dset.datatype) && throw(ArgumentError("datatype must be provided"))
        isnothing(dset.dataspace) && throw(ArgumentError("dataspace must be provided"))
        datatype = dset.datatype
        dataspace = dset.dataspace

        f = dset.parent.f
        # `write_message` below requires a write session when attributes are serialized
        wsession = JLDWriteSession()
        attributes = map(collect(dset.attributes)) do (name, attr)
            attr isa WrittenAttribute ? attr : WrittenAttribute(f, name, attr)
        end

        writtenas = writeas(T)
        odr_ = _odr(writtenas, T, odr(writtenas))
        datasz = odr_sizeof(odr_)::Int * numel(dataspace)::Int
        psz = payload_size_without_storage_message(dataspace, datatype)::Int
        psz += sum(message_size, attributes; init=0)
        # minimum extra space for continuation message
        psz += jlsizeof(HeaderMessage) + jlsizeof(RelOffset) + jlsizeof(Length)

        # Layout class: only contiguous storage is used for now
        psz += jlsizeof(ContiguousStorageMessage)
        fullsz = jlsizeof(ObjectStart) + size_size(psz) + psz + 4

        header_offset = f.end_of_data
        io = f.io
        seek(io, header_offset)
        f.end_of_data = header_offset + fullsz

        cio = begin_checksum_write(io, fullsz - 4)
        write_object_header_and_dataspace_message(cio, f, psz, dataspace)
        write_datatype_message(cio, datatype)
        for a in attributes
            write_message(cio, f, a, wsession)
        end

        # Align contiguous chunk to 8 bytes in the file
        address = f.end_of_data + 8 - mod1(f.end_of_data, 8)
        data_offset = h5offset(f, address)
        jlwrite(cio, ContiguousStorageMessage(datasz, data_offset))

        dset.header_chunk_info = (header_offset, position(cio)+20, position(cio))
        # Add NIL message replaceable by continuation message
        jlwrite(io, CONTINUATION_PLACEHOLDER)
        jlwrite(io, end_checksum(cio))

        # Reserve the data region; nothing is written into it here.
        f.end_of_data = address + datasz
        seek(io, f.end_of_data)

        offset = h5offset(f, header_offset)
        # Unnamed datasets (empty name) are not linked into the parent.
        !isempty(dset.name) && (dset.parent[dset.name] = offset)

        # load current dataset as new dataset
        ddset = get_dataset(f, offset, dset.parent, dset.name)
        for field in fieldnames(Dataset)
            setproperty!(dset, field, getfield(ddset, field))
        end
        return offset
    end
end
88 changes: 88 additions & 0 deletions test/mmap_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
using JLD2, Test

@testset "Mmapped Arrays" begin
    cd(mktempdir()) do
        # Fixtures: two numeric arrays, a scalar and an array of Bool tuples.
        a = rand(100, 100)
        b = rand(ComplexF64, 5, 5)
        c = 42
        d = [ntuple(x -> Bool(x % 2), Val(24)) for i in 1:100]

        fn = "test.jld2"
        jldsave(fn; a, b, c, d)

        # Read-only access: `a` and `b` can be mapped and match the originals.
        jldopen(fn, "r") do f
            for (name, expected) in (("a", a), ("b", b))
                dset = JLD2.get_dataset(f, name)
                @test JLD2.ismmappable(dset)
                @test JLD2.readmmap(dset) == expected
            end
            @test JLD2.ismmappable(JLD2.get_dataset(f, "c")) == false
            @test JLD2.ismmappable(JLD2.get_dataset(f, "d")) == true
        end

        if Sys.iswindows()
            # Writable files are reported as not mappable on Windows, with a warning
            # for datasets that would otherwise qualify.
            jldopen(fn, "a") do f
                dset_a = JLD2.get_dataset(f, "a")
                @test JLD2.ismmappable(dset_a) == false
                @test_logs (:warn, "On Windows memory-mapping is only possible for files in read-only mode.") JLD2.ismmappable(dset_a)
                dset_c = JLD2.get_dataset(f, "c")
                @test JLD2.ismmappable(dset_c) == false
                @test_nowarn JLD2.ismmappable(dset_c)
            end
        else
            # Append mode: write through the mapping, then verify after reopening.
            jldopen(fn, "a") do f
                for (name, expected, replacement) in (("a", a, 42.0), ("b", b, 4.0 + 2.0im))
                    dset = JLD2.get_dataset(f, name)
                    @test JLD2.ismmappable(dset)
                    @test JLD2.readmmap(dset) == expected
                    JLD2.readmmap(dset)[1, 1] = replacement
                end

                @test JLD2.ismmappable(JLD2.get_dataset(f, "c")) == false
                @test JLD2.ismmappable(JLD2.get_dataset(f, "d")) == true
            end

            jldopen(fn, "r") do f
                @test f["a"][1, 1] == 42.0
                @test f["b"][1, 1] == 4.0 + 2.0im
                @test f["d"] == d
            end
        end
    end
end

if !Sys.iswindows()
    @testset "Early Allocation" begin
        # Update this for proper API eventually
        # Use a dedicated temporary file so this testset does not rely on a `fn`
        # variable defined elsewhere.
        fn = joinpath(mktempdir(), "test.jld2")
        jldopen(fn, "w") do f
            dset = JLD2.create_dataset(f, "data")

            dset.datatype = JLD2.h5fieldtype(f, Float64, Float64, Val{false})

            dims = (100, 100)
            dset.dataspace = JLD2.WriteDataspace(JLD2.DS_SIMPLE, UInt64.(reverse(dims)), ())

            JLD2.allocate_early(dset, Float64)

            @test JLD2.ismmappable(dset)

            emptyarr = JLD2.readmmap(dset)

            # Fill every second element (linear indexing) through the mapping.
            emptyarr[1:2:100] .= 1:50
        end

        data = JLD2.load(fn, "data")
        @test all(data[2:2:100] .== 0.0)
        @test all(data[1:2:100] .== 1:50)
    end
end
Loading

0 comments on commit cc86e5d

Please sign in to comment.