Skip to content

Commit

Permalink
support vlen arrays in compound datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
kleinhenz committed Feb 29, 2020
1 parent 4e3f6c9 commit dfe1667
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 9 deletions.
23 changes: 17 additions & 6 deletions src/HDF5.jl
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,12 @@ struct FixedString{N}
end
length(::Type{FixedString{N}}) where N = N

# In-memory record for one variable-length array value: an element count plus an
# untyped pointer to the data. The type parameter T only records the element
# type; it is not stored in any field.
# NOTE(review): field order/types presumably have to match HDF5's C hvl_t
# (length first, then pointer) — confirm before reordering.
struct VariableArray{T}
# number of elements behind `p`; used as the array length when the buffer is wrapped
len::Csize_t
# untyped data pointer; converted to a Ptr of the element type before being wrapped
p::Ptr{Cvoid}
end
# Element type recorded in the type parameter of a `VariableArray` type.
function eltype(::Type{VariableArray{T}}) where {T}
    return T
end

# VLEN objects
struct HDF5Vlen{T}
data
Expand Down Expand Up @@ -1459,7 +1465,6 @@ function getindex(parent::Union{HDF5File, HDF5Group, HDF5Dataset}, r::HDF5Refere
h5object(obj_id, parent)
end


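# convert HDF5 in-memory field types to native Julia values: Cstring/FixedString -> String, FixedArray/VariableArray -> Array; nested NamedTuples are normalized recursively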
function normalize_types(x::NamedTuple{T}) where T

Expand All @@ -1471,6 +1476,8 @@ function normalize_types(x::NamedTuple{T}) where T
join(Char.(xi.data))
elseif Ti <: FixedArray
reshape(collect(xi.data), size(Ti)...)
elseif Ti <: VariableArray
copy(unsafe_wrap(Array, convert(Ptr{eltype(xi)}, xi.p), xi.len, own=false))
elseif Ti <: NamedTuple
normalize_types(xi)
else
Expand Down Expand Up @@ -1504,10 +1511,10 @@ function read(dset::HDF5Dataset, T::Union{Type{Array{U}}, Type{U}}) where U <: N
HDF5.h5d_read(dset.id, memtype_id, HDF5.H5S_ALL, HDF5.H5S_ALL, HDF5.H5P_DEFAULT, buf)

types = get_all_types(U)
normalize = any(t -> t <: Union{Cstring, FixedString, FixedArray}, types)
normalize = any(t -> t <: Union{Cstring, FixedString, FixedArray, VariableArray}, types)
out = normalize ? normalize_types.(buf) : buf

reclaim = any(t -> t <: Cstring, types)
reclaim = any(t -> t <: Union{Cstring, VariableArray}, types)
if reclaim
dspace = dataspace(dset)
# NOTE I have seen this call fail but I cannot reproduce
Expand Down Expand Up @@ -1570,6 +1577,7 @@ function read(obj::DatasetOrAttribute, ::Type{HDF5Vlen{T}}) where {T<:Union{HDF5
for i = 1:len
h = structbuf[i]
data[i] = p2a(convert(Ptr{T}, h.p), Int(h.len))

end
data
end
Expand Down Expand Up @@ -2011,15 +2019,18 @@ function hdf5_to_julia_eltype(objtype)
dtype = HDF5Datatype(h5t_get_member_type(objtype.id, i-1))
ci = h5t_get_class(dtype.id)

if ci != H5T_STRING
return hdf5_to_julia_eltype(dtype)
else
if ci == H5T_STRING
if h5t_is_variable_str(dtype.id)
return Cstring
else
n = h5t_get_size(dtype.id)
return FixedString{Int(n)}
end
elseif ci == H5T_VLEN
superid = h5t_get_super(dtype.id)
T = VariableArray{hdf5_to_julia_eltype(HDF5Datatype(superid))}
else
return hdf5_to_julia_eltype(dtype)
end
end

Expand Down
22 changes: 19 additions & 3 deletions test/compound.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,24 @@ struct foo
b::String
c::String
d::Array{ComplexF64,2}
e::Array{Int64,1}
end

# Flat, C-compatible counterpart of `foo`; values are produced from a `foo` by
# `unsafe_convert` below. The HDF5 compound datatype is built from this struct's
# `fieldoffset`s, so field order and types must not be changed independently of it.
struct foo_hdf5
a::Float64
# C string pointer obtained from `foo.b`
b::Cstring
# fixed-size string: 10 C chars filled from the characters of `foo.c`
c::NTuple{10, Cchar}
# 9 complex values filled from `foo.d` in linear-index order
d::NTuple{9, ComplexF64}
# variable-length array descriptor (length + data pointer) built from `foo.e`
e::HDF5.Hvl_t
end

# Build the HDF5-compatible `foo_hdf5` representation of `x`.
#
# The result holds raw pointers into `x.b` (the Cstring) and `x.e` (the Hvl_t
# data pointer). Nothing in the returned value roots those objects, so the
# caller must keep `x` alive for as long as the result is in use.
#
# The stale four-argument `foo_hdf5(...)` call from before the `e` field was
# added has been removed: `foo_hdf5` now has five fields, so that call could
# only throw before the correct constructor call was reached.
function unsafe_convert(::Type{foo_hdf5}, x::foo)
    return foo_hdf5(
        x.a,
        Base.unsafe_convert(Cstring, x.b),
        ntuple(i -> x.c[i], length(x.c)),
        ntuple(i -> x.d[i], length(x.d)),
        # vlen field: element count plus an untyped pointer to the array data
        HDF5.Hvl_t(convert(Csize_t, length(x.e)), convert(Ptr{Cvoid}, pointer(x.e)))
    )
end

function datatype(::Type{foo_hdf5})
Expand All @@ -39,12 +46,21 @@ function datatype(::Type{foo_hdf5})
array_dtype = HDF5.h5t_array_create(datatype(ComplexF64).id, 2, hsz)
HDF5.h5t_insert(dtype, "d", fieldoffset(foo_hdf5, 4), array_dtype)

vlen_dtype = HDF5.h5t_vlen_create(datatype(Int64))
HDF5.h5t_insert(dtype, "e", fieldoffset(foo_hdf5, 5), vlen_dtype)

HDF5Datatype(dtype)
end

@testset "compound" begin
N = 10
v = [foo(rand(), randstring(rand(10:100)), randstring(10), rand(ComplexF64, 3,3)) for _ in 1:N]
v = [foo(rand(),
randstring(rand(10:100)),
randstring(10),
rand(ComplexF64, 3,3),
rand(1:10, rand(10:100))
)
for _ in 1:N]
v_write = unsafe_convert.(foo_hdf5, v)

fn = tempname()
Expand All @@ -56,7 +72,7 @@ end
end

v_read = h5read(fn, "data")
for field in (:a, :b, :c, :d)
for field in (:a, :b, :c, :d, :e)
f = x -> getfield(x, field)
@test f.(v) == f.(v_read)
end
Expand Down

0 comments on commit dfe1667

Please sign in to comment.