Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auto initialize in startproc #74

Merged
merged 6 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions src/compression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,6 @@
# Methods
# -------

function TranscodingStreams.initialize(codec::ZstdCompressor)
code = initialize!(codec.cstream, codec.level)
if iserror(code)
zstderror(codec.cstream, code)
end
reset!(codec.cstream.ibuffer)
reset!(codec.cstream.obuffer)
Comment on lines -86 to -87
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where does this happen now?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This happens in reset!

reset!(cstream.ibuffer)
reset!(cstream.obuffer)

Which is called in startproc

code = reset!(codec.cstream, 0 #=unknown source size=#)

return
end

function TranscodingStreams.finalize(codec::ZstdCompressor)
if codec.cstream.ptr != C_NULL
code = free!(codec.cstream)
Expand All @@ -96,12 +86,22 @@
end
codec.cstream.ptr = C_NULL
end
reset!(codec.cstream.ibuffer)
reset!(codec.cstream.obuffer)
return
nothing
nhz2 marked this conversation as resolved.
Show resolved Hide resolved
nhz2 marked this conversation as resolved.
Show resolved Hide resolved
end

function TranscodingStreams.startproc(codec::ZstdCompressor, mode::Symbol, error::Error)
if codec.cstream.ptr == C_NULL
ptr = LibZstd.ZSTD_createCStream()
if ptr == C_NULL
throw(OutOfMemoryError())

Check warning on line 96 in src/compression.jl

View check run for this annotation

Codecov / codecov/patch

src/compression.jl#L96

Added line #L96 was not covered by tests
end
codec.cstream.ptr = ptr
nhz2 marked this conversation as resolved.
Show resolved Hide resolved
i_code = initialize!(codec.cstream, codec.level)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also see notes in #73

Should initialize! throw so we can catch it here and transmit the error?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, if initialize! is changed to throw, then this needs to catch that error and return :error, but for now initialize! returns an error code on failure.

if iserror(i_code)
error[] = ErrorException("zstd initialization error")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These errors are unreachable unless there is some allocation error.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you mock the out of memory condition them by using that advanced API that provides the memory allocation functions?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if even that would reliably trigger an error specifically here, because memory allocations are happening in ZSTD_createCStream.

return :error

Check warning on line 102 in src/compression.jl

View check run for this annotation

Codecov / codecov/patch

src/compression.jl#L101-L102

Added lines #L101 - L102 were not covered by tests
end
end
code = reset!(codec.cstream, 0 #=unknown source size=#)
if iserror(code)
error[] = ErrorException("zstd error")
Expand All @@ -111,6 +111,9 @@
end

function TranscodingStreams.process(codec::ZstdCompressor, input::Memory, output::Memory, error::Error)
if codec.cstream.ptr == C_NULL
error("startproc must be called before process")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This error should also be unreachable in normal operation.

end
cstream = codec.cstream
ibuffer_starting_pos = UInt(0)
if codec.endOp == LibZstd.ZSTD_e_end &&
Expand Down
29 changes: 16 additions & 13 deletions src/decompression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,6 @@
# Methods
# -------

function TranscodingStreams.initialize(codec::ZstdDecompressor)
code = initialize!(codec.dstream)
if iserror(code)
zstderror(codec.dstream, code)
end
reset!(codec.dstream.ibuffer)
reset!(codec.dstream.obuffer)
return
end

function TranscodingStreams.finalize(codec::ZstdDecompressor)
if codec.dstream.ptr != C_NULL
code = free!(codec.dstream)
Expand All @@ -51,12 +41,22 @@
end
codec.dstream.ptr = C_NULL
end
reset!(codec.dstream.ibuffer)
reset!(codec.dstream.obuffer)
return
nothing
end

function TranscodingStreams.startproc(codec::ZstdDecompressor, mode::Symbol, error::Error)
if codec.dstream.ptr == C_NULL
ptr = LibZstd.ZSTD_createDStream()
if ptr == C_NULL
throw(OutOfMemoryError())

Check warning on line 51 in src/decompression.jl

View check run for this annotation

Codecov / codecov/patch

src/decompression.jl#L51

Added line #L51 was not covered by tests
end
codec.dstream.ptr = ptr
nhz2 marked this conversation as resolved.
Show resolved Hide resolved
i_code = initialize!(codec.dstream)
if iserror(i_code)
error[] = ErrorException("zstd initialization error")
return :error

Check warning on line 57 in src/decompression.jl

View check run for this annotation

Codecov / codecov/patch

src/decompression.jl#L56-L57

Added lines #L56 - L57 were not covered by tests
end
end
code = reset!(codec.dstream)
if iserror(code)
error[] = ErrorException("zstd error")
Expand All @@ -66,6 +66,9 @@
end

function TranscodingStreams.process(codec::ZstdDecompressor, input::Memory, output::Memory, error::Error)
if codec.dstream.ptr == C_NULL
error("startproc must be called before process")
end
dstream = codec.dstream
dstream.ibuffer.src = input.ptr
dstream.ibuffer.size = input.size
Expand Down
12 changes: 2 additions & 10 deletions src/libzstd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,7 @@ mutable struct CStream
obuffer::OutBuffer

function CStream()
ptr = LibZstd.ZSTD_createCStream()
if ptr == C_NULL
throw(OutOfMemoryError())
end
return new(ptr, InBuffer(), OutBuffer())
return new(C_NULL, InBuffer(), OutBuffer())
end
end

Expand Down Expand Up @@ -127,11 +123,7 @@ mutable struct DStream
obuffer::OutBuffer

function DStream()
ptr = LibZstd.ZSTD_createDStream()
if ptr == C_NULL
throw(OutOfMemoryError())
end
return new(ptr, InBuffer(), OutBuffer())
return new(C_NULL, InBuffer(), OutBuffer())
end
end
Base.unsafe_convert(::Type{Ptr{LibZstd.ZSTD_DStream}}, dstream::DStream) = dstream.ptr
Expand Down
3 changes: 3 additions & 0 deletions test/compress_endOp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using Test
@testset "compress! endOp = :continue" begin
data = rand(1:100, 1024*1024)
nhz2 marked this conversation as resolved.
Show resolved Hide resolved
cstream = CodecZstd.CStream()
cstream.ptr = CodecZstd.LibZstd.ZSTD_createCStream()
cstream.ibuffer.src = pointer(data)
cstream.ibuffer.size = sizeof(data)
cstream.ibuffer.pos = 0
Expand All @@ -24,6 +25,7 @@ end
@testset "compress! endOp = :flush" begin
data = rand(1:100, 1024*1024)
cstream = CodecZstd.CStream()
cstream.ptr = CodecZstd.LibZstd.ZSTD_createCStream()
cstream.ibuffer.src = pointer(data)
cstream.ibuffer.size = sizeof(data)
cstream.ibuffer.pos = 0
Expand All @@ -43,6 +45,7 @@ end
@testset "compress! endOp = :end" begin
data = rand(1:100, 1024*1024)
cstream = CodecZstd.CStream()
cstream.ptr = CodecZstd.LibZstd.ZSTD_createCStream()
cstream.ibuffer.src = pointer(data)
cstream.ibuffer.size = sizeof(data)
cstream.ibuffer.pos = 0
Expand Down
68 changes: 68 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,4 +158,72 @@ include("utils.jl")

include("compress_endOp.jl")
include("static_only_tests.jl")

@testset "reusing a compressor" begin
compressor = ZstdCompressor()
x = rand(UInt8, 1000)
TranscodingStreams.initialize(compressor)
ret1 = transcode(compressor, x)
TranscodingStreams.finalize(compressor)

# compress again using the same compressor
TranscodingStreams.initialize(compressor) # segfault happens here!
ret2 = transcode(compressor, x)
ret3 = transcode(compressor, x)
TranscodingStreams.finalize(compressor)

@test transcode(ZstdDecompressor, ret1) == x
@test transcode(ZstdDecompressor, ret2) == x
@test transcode(ZstdDecompressor, ret3) == x
@test ret1 == ret2
@test ret1 == ret3

decompressor = ZstdDecompressor()
TranscodingStreams.initialize(decompressor)
@test transcode(decompressor, ret1) == x
TranscodingStreams.finalize(decompressor)

TranscodingStreams.initialize(decompressor)
@test transcode(decompressor, ret1) == x
TranscodingStreams.finalize(decompressor)
end

@testset "use after free doesn't segfault" begin
@testset "$(Codec)" for Codec in (ZstdCompressor, ZstdDecompressor)
codec = Codec()
TranscodingStreams.initialize(codec)
TranscodingStreams.finalize(codec)
data = [0x00,0x01]
GC.@preserve data let m = TranscodingStreams.Memory(pointer(data), length(data))
try
TranscodingStreams.expectedsize(codec, m)
catch
end
try
nhz2 marked this conversation as resolved.
Show resolved Hide resolved
TranscodingStreams.minoutsize(codec, m)
catch
end
try
TranscodingStreams.initialize(codec)
catch
end
try
TranscodingStreams.process(codec, m, m, TranscodingStreams.Error())
catch
end
try
TranscodingStreams.startproc(codec, :read, TranscodingStreams.Error())
catch
end
try
TranscodingStreams.process(codec, m, m, TranscodingStreams.Error())
catch
end
try
TranscodingStreams.finalize(codec)
catch
end
end
end
end
end
Loading