diff --git a/Project.toml b/Project.toml index 268274b..b8a5449 100644 --- a/Project.toml +++ b/Project.toml @@ -2,7 +2,7 @@ name = "JLSO" uuid = "9da8a3cd-07a3-59c0-a743-3fdc52c30d11" license = "MIT" authors = ["Invenia Technical Computing Corperation"] -version = "2.5.0" +version = "2.6.0" [deps] BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" diff --git a/src/JLSOFile.jl b/src/JLSOFile.jl index 1bcba1d..fcd3b9e 100644 --- a/src/JLSOFile.jl +++ b/src/JLSOFile.jl @@ -75,8 +75,9 @@ function JLSOFile(; image=_image(), kwargs... ) + data = isempty(kwargs) ? Dict{Symbol,Any}() : Dict(kwargs) return JLSOFile( - Dict(kwargs); + data; version=version, julia=julia, format=format, @@ -117,3 +118,18 @@ function Base.:(==)(a::JLSOFile, b::JLSOFile) end Base.names(jlso::JLSOFile) = collect(keys(jlso.objects)) + +Base.keys(jlso::JLSOFile) = keys(jlso.objects) + +Base.haskey(jlso::JLSOFile, key) = haskey(jlso.objects, key) + +Base.get(jlso::JLSOFile, key, default) = haskey(jlso, key) ? jlso[key] : default + +Base.get!(jlso::JLSOFile, key, default) = get!(() -> default, jlso, key) +function Base.get!(func, jlso::JLSOFile, key) + return if haskey(jlso, key) + jlso[key] + else + jlso[key] = func() + end +end diff --git a/src/file_io.jl b/src/file_io.jl index 56b18d5..823624d 100644 --- a/src/file_io.jl +++ b/src/file_io.jl @@ -105,7 +105,8 @@ end Creates a JLSOFile with the specified data and kwargs and writes it back to the io. """ -save(io::IO, data; kwargs...) = write(io, JLSOFile(data; kwargs...)) +save(io::IO, data::JLSOFile) = write(io, data) +save(io::IO, data; kwargs...) = save(io, JLSOFile(data; kwargs...)) save(io::IO, data::Pair...; kwargs...) = save(io, Dict(data...); kwargs...) function save(path::Union{AbstractPath, AbstractString}, args...; kwargs...) return open(io -> save(io, args...; kwargs...), path, "w") diff --git a/test/JLSOFile.jl b/test/JLSOFile.jl index f5bff60..cb659a6 100644 --- a/test/JLSOFile.jl +++ b/test/JLSOFile.jl @@ -22,6 +22,13 @@ @test jlso[:b] == "hello" @test haskey(jlso.manifest, "BSON") end + + @testset "no-arg constructor" begin + jlso = JLSOFile() + @test jlso isa JLSOFile + @test isempty(jlso.objects) + @test haskey(jlso.manifest, "BSON") + end end @testset "unknown format" begin @@ -55,3 +62,27 @@ end @test Pkg.TOML.parsefile(joinpath(d, "Manifest.toml")) == jlso.manifest end end + +@testset "keys/haskey" begin + jlso = JLSOFile(:string => datas[:String]) + @test collect(keys(jlso)) == [:string] + @test haskey(jlso, :string) + @test !haskey(jlso, :other) +end + +@testset "get/get!" begin + v = datas[:String] + jlso = JLSOFile(:str => v) + @test get(jlso, :str, "fail") == v + @test get!(jlso, :str, "fail") == v + + @test get(jlso, :other, v) == v + @test !haskey(jlso, :other) + + @test get!(jlso, :other, v) == v + @test jlso[:other] == v + + # key must be a Symbol + @test get(jlso, "str", 999) == 999 + @test_throws MethodError get!(jlso, "str", 999) +end