diff --git a/src/compression.jl b/src/compression.jl index cabc3f9..fc9361d 100644 --- a/src/compression.jl +++ b/src/compression.jl @@ -78,16 +78,6 @@ end # Methods # ------- -function TranscodingStreams.initialize(codec::ZstdCompressor) - code = initialize!(codec.cstream, codec.level) - if iserror(code) - zstderror(codec.cstream, code) - end - reset!(codec.cstream.ibuffer) - reset!(codec.cstream.obuffer) - return -end - function TranscodingStreams.finalize(codec::ZstdCompressor) if codec.cstream.ptr != C_NULL code = free!(codec.cstream) @@ -96,12 +86,21 @@ function TranscodingStreams.finalize(codec::ZstdCompressor) end codec.cstream.ptr = C_NULL end - reset!(codec.cstream.ibuffer) - reset!(codec.cstream.obuffer) return end function TranscodingStreams.startproc(codec::ZstdCompressor, mode::Symbol, error::Error) + if codec.cstream.ptr == C_NULL + codec.cstream.ptr = LibZstd.ZSTD_createCStream() + if codec.cstream.ptr == C_NULL + throw(OutOfMemoryError()) + end + i_code = initialize!(codec.cstream, codec.level) + if iserror(i_code) + error[] = ErrorException("zstd initialization error") + return :error + end + end code = reset!(codec.cstream, 0 #=unknown source size=#) if iserror(code) error[] = ErrorException("zstd error") @@ -111,6 +110,9 @@ function TranscodingStreams.startproc(codec::ZstdCompressor, mode::Symbol, error end function TranscodingStreams.process(codec::ZstdCompressor, input::Memory, output::Memory, error::Error) + if codec.cstream.ptr == C_NULL + error("startproc must be called before process") + end cstream = codec.cstream ibuffer_starting_pos = UInt(0) if codec.endOp == LibZstd.ZSTD_e_end && diff --git a/src/decompression.jl b/src/decompression.jl index 765ce2c..7ed15cb 100644 --- a/src/decompression.jl +++ b/src/decompression.jl @@ -33,16 +33,6 @@ end # Methods # ------- -function TranscodingStreams.initialize(codec::ZstdDecompressor) - code = initialize!(codec.dstream) - if iserror(code) - zstderror(codec.dstream, code) - end - reset!(codec.dstream.ibuffer) - reset!(codec.dstream.obuffer) - return -end - function TranscodingStreams.finalize(codec::ZstdDecompressor) if codec.dstream.ptr != C_NULL code = free!(codec.dstream) @@ -51,12 +41,21 @@ function TranscodingStreams.finalize(codec::ZstdDecompressor) end codec.dstream.ptr = C_NULL end - reset!(codec.dstream.ibuffer) - reset!(codec.dstream.obuffer) return end function TranscodingStreams.startproc(codec::ZstdDecompressor, mode::Symbol, error::Error) + if codec.dstream.ptr == C_NULL + codec.dstream.ptr = LibZstd.ZSTD_createDStream() + if codec.dstream.ptr == C_NULL + throw(OutOfMemoryError()) + end + i_code = initialize!(codec.dstream) + if iserror(i_code) + error[] = ErrorException("zstd initialization error") + return :error + end + end code = reset!(codec.dstream) if iserror(code) error[] = ErrorException("zstd error") @@ -66,6 +65,9 @@ function TranscodingStreams.startproc(codec::ZstdDecompressor, mode::Symbol, err end function TranscodingStreams.process(codec::ZstdDecompressor, input::Memory, output::Memory, error::Error) + if codec.dstream.ptr == C_NULL + error("startproc must be called before process") + end dstream = codec.dstream dstream.ibuffer.src = input.ptr dstream.ibuffer.size = input.size diff --git a/src/libzstd.jl b/src/libzstd.jl index c11b1f1..f9865b6 100644 --- a/src/libzstd.jl +++ b/src/libzstd.jl @@ -44,11 +44,7 @@ mutable struct CStream obuffer::OutBuffer function CStream() - ptr = LibZstd.ZSTD_createCStream() - if ptr == C_NULL - throw(OutOfMemoryError()) - end - return new(ptr, InBuffer(), OutBuffer()) + return new(C_NULL, InBuffer(), OutBuffer()) end end @@ -127,11 +123,7 @@ mutable struct DStream obuffer::OutBuffer function DStream() - ptr = LibZstd.ZSTD_createDStream() - if ptr == C_NULL - throw(OutOfMemoryError()) - end - return new(ptr, InBuffer(), OutBuffer()) + return new(C_NULL, InBuffer(), OutBuffer()) end end Base.unsafe_convert(::Type{Ptr{LibZstd.ZSTD_DStream}}, dstream::DStream) = dstream.ptr @@ -145,6 +137,8 @@ end function reset!(dstream::DStream) # LibZstd.ZSTD_resetDStream is deprecated # https://github.com/facebook/zstd/blob/9d2a45a705e22ad4817b41442949cd0f78597154/lib/zstd.h#L2332-L2339 + reset!(dstream.ibuffer) + reset!(dstream.obuffer) return LibZstd.ZSTD_DCtx_reset(dstream, LibZstd.ZSTD_reset_session_only) end diff --git a/test/compress_endOp.jl b/test/compress_endOp.jl index 0594f1f..f5f120d 100644 --- a/test/compress_endOp.jl +++ b/test/compress_endOp.jl @@ -3,27 +3,29 @@ using Test @testset "compress! endOp = :continue" begin data = rand(1:100, 1024*1024) - cstream = CodecZstd.CStream() - cstream.ibuffer.src = pointer(data) - cstream.ibuffer.size = sizeof(data) - cstream.ibuffer.pos = 0 - cstream.obuffer.dst = Base.Libc.malloc(sizeof(data)*2) - cstream.obuffer.size = sizeof(data)*2 - cstream.obuffer.pos = 0 - try - GC.@preserve data begin + GC.@preserve data begin + cstream = CodecZstd.CStream() + cstream.ptr = CodecZstd.LibZstd.ZSTD_createCStream() + cstream.ibuffer.src = pointer(data) + cstream.ibuffer.size = sizeof(data) + cstream.ibuffer.pos = 0 + cstream.obuffer.dst = Base.Libc.malloc(sizeof(data)*2) + cstream.obuffer.size = sizeof(data)*2 + cstream.obuffer.pos = 0 + try # default endOp @test CodecZstd.compress!(cstream; endOp=:continue) == 0 @test CodecZstd.find_decompressed_size(cstream.obuffer.dst, cstream.obuffer.pos) == CodecZstd.ZSTD_CONTENTSIZE_UNKNOWN + finally + Base.Libc.free(cstream.obuffer.dst) end - finally - Base.Libc.free(cstream.obuffer.dst) end end @testset "compress! endOp = :flush" begin data = rand(1:100, 1024*1024) cstream = CodecZstd.CStream() + cstream.ptr = CodecZstd.LibZstd.ZSTD_createCStream() cstream.ibuffer.src = pointer(data) cstream.ibuffer.size = sizeof(data) cstream.ibuffer.pos = 0 @@ -43,6 +45,7 @@ end @testset "compress! endOp = :end" begin data = rand(1:100, 1024*1024) cstream = CodecZstd.CStream() + cstream.ptr = CodecZstd.LibZstd.ZSTD_createCStream() cstream.ibuffer.src = pointer(data) cstream.ibuffer.size = sizeof(data) cstream.ibuffer.pos = 0 diff --git a/test/runtests.jl b/test/runtests.jl index a111d9a..73cf82f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -158,4 +158,72 @@ include("utils.jl") include("compress_endOp.jl") include("static_only_tests.jl") + + @testset "reusing a compressor" begin + compressor = ZstdCompressor() + x = rand(UInt8, 1000) + TranscodingStreams.initialize(compressor) + ret1 = transcode(compressor, x) + TranscodingStreams.finalize(compressor) + + # compress again using the same compressor + TranscodingStreams.initialize(compressor) # segfault happens here! + ret2 = transcode(compressor, x) + ret3 = transcode(compressor, x) + TranscodingStreams.finalize(compressor) + + @test transcode(ZstdDecompressor, ret1) == x + @test transcode(ZstdDecompressor, ret2) == x + @test transcode(ZstdDecompressor, ret3) == x + @test ret1 == ret2 + @test ret1 == ret3 + + decompressor = ZstdDecompressor() + TranscodingStreams.initialize(decompressor) + @test transcode(decompressor, ret1) == x + TranscodingStreams.finalize(decompressor) + + TranscodingStreams.initialize(decompressor) + @test transcode(decompressor, ret1) == x + TranscodingStreams.finalize(decompressor) + end + + @testset "use after free doesn't segfault" begin + @testset "$(Codec)" for Codec in (ZstdCompressor, ZstdDecompressor) + codec = Codec() + TranscodingStreams.initialize(codec) + TranscodingStreams.finalize(codec) + data = [0x00,0x01] + GC.@preserve data let m = TranscodingStreams.Memory(pointer(data), length(data)) + try + TranscodingStreams.expectedsize(codec, m) + catch + end + try + TranscodingStreams.minoutsize(codec, m) + catch + end + try + TranscodingStreams.initialize(codec) + catch + end + try + TranscodingStreams.process(codec, m, m, TranscodingStreams.Error()) + catch + end + try + TranscodingStreams.startproc(codec, :read, TranscodingStreams.Error()) + catch + end + try + TranscodingStreams.process(codec, m, m, TranscodingStreams.Error()) + catch + end + try + TranscodingStreams.finalize(codec) + catch + end + end + end + end end