diff --git a/src/CodecZstd.jl b/src/CodecZstd.jl index 315f5f2..c447c07 100644 --- a/src/CodecZstd.jl +++ b/src/CodecZstd.jl @@ -4,7 +4,10 @@ export ZstdCompressor, ZstdCompressorStream, ZstdDecompressor, - ZstdDecompressorStream + ZstdDecompressorStream, + ZstdFrameCompressor, + ZstdFrameDecompressor, + ZstdError import TranscodingStreams: TranscodingStreams, @@ -23,5 +26,7 @@ include("LibZstd_clang.jl") include("libzstd.jl") include("compression.jl") include("decompression.jl") +include("frameCompression.jl") +include("frameDecompression.jl") end # module diff --git a/src/compression.jl b/src/compression.jl index 36b93a4..a00187f 100644 --- a/src/compression.jl +++ b/src/compression.jl @@ -11,20 +11,21 @@ function Base.show(io::IO, codec::ZstdCompressor) end # Same as the zstd command line tool (v1.2.0). -const DEFAULT_COMPRESSION_LEVEL = 3 +const DEFAULT_COMPRESSION_LEVEL = DEFAULT_CLEVEL """ ZstdCompressor(;level=$(DEFAULT_COMPRESSION_LEVEL)) -Create a new zstd compression codec. +Create a new zstd compression codec using the streaming API. +This compressor uses `ZSTD_compressStream`. Arguments --------- -- `level`: compression level (1..$(MAX_CLEVEL)) +- `level`: compression level ($(MIN_CLEVEL)..$(MAX_CLEVEL)) """ function ZstdCompressor(;level::Integer=DEFAULT_COMPRESSION_LEVEL) - if !(1 ≤ level ≤ MAX_CLEVEL) - throw(ArgumentError("level must be within 1..$(MAX_CLEVEL)")) + if !(MIN_CLEVEL ≤ level ≤ MAX_CLEVEL) + throw(ArgumentError("level must be within $(MIN_CLEVEL)..$(MAX_CLEVEL)")) end return ZstdCompressor(CStream(), level) end diff --git a/src/decompression.jl b/src/decompression.jl index 6767634..8b674d4 100644 --- a/src/decompression.jl +++ b/src/decompression.jl @@ -12,7 +12,8 @@ end """ ZstdDecompressor() -Create a new zstd decompression codec. +Create a new zstd decompression codec using the streaming API. +This decompressor uses `ZSTD_decompressStream`. 
""" function ZstdDecompressor() return ZstdDecompressor(DStream()) diff --git a/src/frameCompression.jl b/src/frameCompression.jl new file mode 100644 index 0000000..dd9cf83 --- /dev/null +++ b/src/frameCompression.jl @@ -0,0 +1,101 @@ +# Frame Compressor Codec +# ====================== + +struct ZstdFrameCompressor <: TranscodingStreams.Codec + cstream::CStream + level::Int +end + +function Base.show(io::IO, codec::ZstdFrameCompressor) + print(io, summary(codec), "(level=$(codec.level))") +end + +# See compression.jl for DEFAULT_COMPRESSION_LEVEL + +""" + ZstdFrameCompressor(;level=$(DEFAULT_COMPRESSION_LEVEL)) + +Create a new zstd compression codec using the non-streaming API. +This uses `ZSTD_compress2`. This compressor expects the entire input +buffer to be available at once and stores the decompressed length +in the frame header. + +Arguments +--------- +- `level`: compression level ($(MIN_CLEVEL)..$(MAX_CLEVEL)) +""" +function ZstdFrameCompressor(;level::Integer=DEFAULT_COMPRESSION_LEVEL) + if !(MIN_CLEVEL ≤ level ≤ MAX_CLEVEL) + throw(ArgumentError("level must be within $(MIN_CLEVEL)..$(MAX_CLEVEL)")) + end + return ZstdFrameCompressor(CStream(), level) +end + +const ZstdFrameCompressorStream{S} = TranscodingStream{ZstdFrameCompressor,S} where S<:IO + +""" + ZstdFrameCompressorStream(stream::IO; kwargs...) + +Create a new zstd compression stream (see `ZstdFrameCompressor` for `kwargs`). +""" +function ZstdFrameCompressorStream(stream::IO; kwargs...) + x, y = splitkwargs(kwargs, (:level,)) + return TranscodingStream(ZstdFrameCompressor(;x...), stream; y...) 
+end + + +# Methods +# ------- + +function TranscodingStreams.initialize(codec::ZstdFrameCompressor) + code = initialize!(codec.cstream, codec.level) + if iserror(code) + throw(ZstdError(code)) + end + return +end + +function TranscodingStreams.finalize(codec::ZstdFrameCompressor) + if codec.cstream.ptr != C_NULL + code = free!(codec.cstream) + if iserror(code) + throw(ZstdError(code)) + end + codec.cstream.ptr = C_NULL + end + return +end + +function TranscodingStreams.expectedsize(codec::ZstdFrameCompressor, input::Memory) + code = compressed_size_bound(input.size) + if iserror(code) + throw(ZstdError(code)) + end + return Int(code) +end + +function TranscodingStreams.startproc(codec::ZstdFrameCompressor, mode::Symbol, error::Error) + code = reset!(codec.cstream, 0 #=unknown source size=#) + if iserror(code) + error[] = ZstdError(code) + return :error + end + return :ok +end + +function TranscodingStreams.process(codec::ZstdFrameCompressor, input::Memory, output::Memory, error::Error) + cstream = codec.cstream + cstream.ibuffer.src = input.ptr + cstream.ibuffer.size = input.size + cstream.ibuffer.pos = 0 + cstream.obuffer.dst = output.ptr + cstream.obuffer.size = output.size + cstream.obuffer.pos = 0 + code = frameCompress!(cstream) + if iserror(code) + error[] = ZstdError(code) + return 0, 0, :error + else + return Int(input.size), Int(code), :end + end +end diff --git a/src/frameDecompression.jl b/src/frameDecompression.jl new file mode 100644 index 0000000..7502ad1 --- /dev/null +++ b/src/frameDecompression.jl @@ -0,0 +1,92 @@ +# Frame Decompressor Codec +# ======================== + +struct ZstdFrameDecompressor <: TranscodingStreams.Codec + dstream::DStream +end + +function Base.show(io::IO, codec::ZstdFrameDecompressor) + print(io, summary(codec), "()") +end + +""" + ZstdFrameDecompressor() + +Create a new zstd decompression codec. 
+This decompressor uses the non-streaming API, expecting a known length, via `ZSTD_decompressDCtx` +""" +function ZstdFrameDecompressor() + return ZstdFrameDecompressor(DStream()) +end + +const ZstdFrameDecompressorStream{S} = TranscodingStream{ZstdFrameDecompressor,S} where S<:IO + +""" + ZstdFrameDecompressorStream(stream::IO; kwargs...) + +Create a new zstd decompression stream (`kwargs` are passed to `TranscodingStream`). +""" +function ZstdFrameDecompressorStream(stream::IO; kwargs...) + return TranscodingStream(ZstdFrameDecompressor(), stream; kwargs...) +end + + +# Methods +# ------- + +function TranscodingStreams.initialize(codec::ZstdFrameDecompressor) + code = initialize!(codec.dstream) + if iserror(code) + throw(ZstdError(code)) + end + return +end + +function TranscodingStreams.finalize(codec::ZstdFrameDecompressor) + if codec.dstream.ptr != C_NULL + code = free!(codec.dstream) + if iserror(code) + throw(ZstdError(code)) + end + codec.dstream.ptr = C_NULL + end + return +end + +function TranscodingStreams.startproc(codec::ZstdFrameDecompressor, mode::Symbol, error::Error) + code = reset!(codec.dstream) + if iserror(code) + error[] = ZstdError(code) + return :error + end + return :ok +end + +function TranscodingStreams.process(codec::ZstdFrameDecompressor, input::Memory, output::Memory, error::Error) + dstream = codec.dstream + dstream.ibuffer.src = input.ptr + dstream.ibuffer.size = input.size + dstream.ibuffer.pos = 0 + dstream.obuffer.dst = output.ptr + dstream.obuffer.size = output.size + dstream.obuffer.pos = 0 + code = frameDecompress!(dstream) + if iserror(code) + error[] = ZstdError(code) + return 0, 0, :error + else + return Int(input.size), Int(code), :end + end +end + +function TranscodingStreams.expectedsize(codec::ZstdFrameDecompressor, input::Memory) + ret = find_decompressed_size(input.ptr, input.size) + if ret == ZSTD_CONTENTSIZE_ERROR + throw(ZstdError()) + elseif ret == ZSTD_CONTENTSIZE_UNKNOWN + return 
Int(decompressed_size_bound(input.ptr, input.size)) + else + # exact size + return Int(ret) + end +end diff --git a/src/libzstd.jl b/src/libzstd.jl index 9906b2b..874c279 100644 --- a/src/libzstd.jl +++ b/src/libzstd.jl @@ -6,15 +6,34 @@ function iserror(code::Csize_t) end function zstderror(stream, code::Csize_t) + zstderror(code) +end +function zstderror(code::Csize_t) ptr = LibZstd.ZSTD_getErrorName(code) error("zstd error: ", unsafe_string(ptr)) end +struct ZstdError <: Exception + code::Csize_t +end +ZstdError() = ZstdError(typemax(Csize_t)) +function Base.show(io::IO, e::ZstdError) + print(io, "ZstdError: ", unsafe_string(LibZstd.ZSTD_getErrorName(e.code))) +end + function max_clevel() return LibZstd.ZSTD_maxCLevel() end +function min_clevel() + return LibZstd.ZSTD_minCLevel() +end +function default_clevel() + return LibZstd.ZSTD_defaultCLevel() +end const MAX_CLEVEL = max_clevel() +const MIN_CLEVEL = min_clevel() +const DEFAULT_CLEVEL = default_clevel() const InBuffer = LibZstd.ZSTD_inBuffer InBuffer() = InBuffer(C_NULL, 0, 0) @@ -98,6 +117,14 @@ function Base.convert(::Type{LibZstd.ZSTD_EndDirective}, endOp::Symbol) return endOp end +function frameCompress!(cstream::CStream) + return LibZstd.ZSTD_compress2( + cstream, + cstream.obuffer.dst, cstream.obuffer.size, + cstream.ibuffer.src, cstream.ibuffer.size + ) +end + function finish!(cstream::CStream) return LibZstd.ZSTD_endStream(cstream, cstream.obuffer) end @@ -138,6 +165,14 @@ function decompress!(dstream::DStream) return LibZstd.ZSTD_decompressStream(dstream, dstream.obuffer, dstream.ibuffer) end +function frameDecompress!(dstream::DStream) + return LibZstd.ZSTD_decompressDCtx( + dstream, + dstream.obuffer.dst, dstream.obuffer.size, + dstream.ibuffer.src, dstream.ibuffer.size + ) +end + function free!(dstream::DStream) return LibZstd.ZSTD_freeDStream(dstream) end @@ -152,3 +187,15 @@ const ZSTD_CONTENTSIZE_ERROR = Culonglong(0) - 2 function find_decompressed_size(src::Ptr, size::Integer) return 
LibZstd.ZSTD_findDecompressedSize(src, size) end +function find_decompressed_size(src::Vector{UInt8}) + return LibZstd.ZSTD_findDecompressedSize(src, sizeof(src)) +end +function compressed_size_bound(sz) + return LibZstd.ZSTD_compressBound(sz) +end +function decompressed_size_bound(src::Ptr, size::Integer) + return LibZstd.ZSTD_decompressBound(src, size) +end +function decompressed_size_bound(src::Vector{UInt8}) + return LibZstd.ZSTD_decompressBound(src, sizeof(src)) +end diff --git a/test/runtests.jl b/test/runtests.jl index cdb1f64..4dc80a0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,6 +18,18 @@ Random.seed!(1234) @test CodecZstd.initialize(codec) === nothing @test CodecZstd.finalize(codec) === nothing + codec = ZstdFrameCompressor() + @test codec isa ZstdFrameCompressor + @test occursin(r"^ZstdFrameCompressor\(level=\d+\)$", sprint(show, codec)) + @test CodecZstd.initialize(codec) === nothing + @test CodecZstd.finalize(codec) === nothing + + codec = ZstdFrameDecompressor() + @test codec isa ZstdFrameDecompressor + @test occursin(r"^ZstdFrameDecompressor\(\)$", sprint(show, codec)) + @test CodecZstd.initialize(codec) === nothing + @test CodecZstd.finalize(codec) === nothing + data = [0x28, 0xb5, 0x2f, 0xfd, 0x04, 0x50, 0x19, 0x00, 0x00, 0x66, 0x6f, 0x6f, 0x3f, 0xba, 0xc4, 0x59] @test read(ZstdDecompressorStream(IOBuffer(data))) == b"foo" @test read(ZstdDecompressorStream(IOBuffer(vcat(data, data)))) == b"foofoo" @@ -43,6 +55,15 @@ Random.seed!(1234) TranscodingStreams.test_roundtrip_write(ZstdCompressorStream, ZstdDecompressorStream) TranscodingStreams.test_roundtrip_lines(ZstdCompressorStream, ZstdDecompressorStream) TranscodingStreams.test_roundtrip_transcode(ZstdCompressor, ZstdDecompressor) + TranscodingStreams.test_roundtrip_transcode(ZstdFrameCompressor, ZstdDecompressor) + TranscodingStreams.test_roundtrip_transcode(ZstdFrameCompressor, ZstdFrameDecompressor) + TranscodingStreams.test_roundtrip_transcode(ZstdCompressor, 
ZstdFrameDecompressor) + + @static if VERSION ≥ v"1.8" + @test_throws "ZstdError: Destination buffer is too small" throw(ZstdError(0xffffffffffffffba)) + else + @test_throws ZstdError throw(ZstdError(0xffffffffffffffba)) + end include("compress_endOp.jl") end