Skip to content

Commit

Permalink
Add tests for find_decompressed_size with skippable frames (#68)
Browse files Browse the repository at this point in the history
* Add tests for `find_decompressed_size` with skippable frames

* add `create_skippable_frame` test util
  • Loading branch information
nhz2 authored Sep 15, 2024
1 parent bcebbb0 commit cef9aaa
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 12 deletions.
25 changes: 19 additions & 6 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ using Test

Random.seed!(1234)

include("utils.jl")

@testset "Zstd Codec" begin
codec = ZstdCompressor()
@test codec isa ZstdCompressor
Expand Down Expand Up @@ -42,6 +44,17 @@ Random.seed!(1234)
end
end

@testset "skippable frames" begin
skippable_frame = create_skippable_frame(b"\r\0\0\0")
u1 = collect(b"")
u2 = collect(b"Hello World!")
c1 = transcode(ZstdCompressor, u1)
c2 = transcode(ZstdCompressor, u2)
@test transcode(ZstdDecompressor, skippable_frame) == UInt8[]
@test transcode(ZstdDecompressor, [skippable_frame; c1;]) == u1
@test transcode(ZstdDecompressor, [skippable_frame; c2;]) == u2
end

@test ZstdCompressorStream <: TranscodingStreams.TranscodingStream
@test ZstdDecompressorStream <: TranscodingStreams.TranscodingStream

Expand Down Expand Up @@ -91,9 +104,9 @@ Random.seed!(1234)
end

@testset "find_decompressed_size" begin
codec = ZstdFrameCompressor()
buffer1 = transcode(codec, "Hello")
buffer2 = transcode(codec, "World!")
codec = ZstdFrameCompressor
buffer1 = transcode(codec, b"Hello")
buffer2 = transcode(codec, b"World!")
@test CodecZstd.find_decompressed_size(buffer1) == 5
@test CodecZstd.find_decompressed_size(buffer2) == 6

Expand All @@ -116,9 +129,9 @@ Random.seed!(1234)
v = take!(iob)
@test CodecZstd.find_decompressed_size(v) == 22

codec = ZstdCompressor()
buffer3 = transcode(codec, "Hello")
buffer4 = transcode(codec, "World!")
codec = ZstdCompressor
buffer3 = transcode(codec, b"Hello")
buffer4 = transcode(codec, b"World!")
@test CodecZstd.find_decompressed_size(buffer3) == CodecZstd.ZSTD_CONTENTSIZE_UNKNOWN
@test CodecZstd.find_decompressed_size(buffer4) == CodecZstd.ZSTD_CONTENTSIZE_UNKNOWN

Expand Down
25 changes: 19 additions & 6 deletions test/static_only_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ These tests use the static-only API to test `find_decompressed_size`
"""

@testset "find_decompressed_size (with static-only API)" begin
codec = ZstdFrameCompressor()
buffer1 = transcode(codec, "Hello")
buffer2 = transcode(codec, "World!")
codec = ZstdFrameCompressor
buffer1 = transcode(codec, b"Hello")
buffer2 = transcode(codec, b"World!")
LibZstdStatic.ZSTD_findDecompressedSize(b::Vector{UInt8}) = LibZstdStatic.ZSTD_findDecompressedSize(b, length(b))
@test CodecZstd.find_decompressed_size(buffer1) == LibZstdStatic.ZSTD_findDecompressedSize(buffer1)
@test CodecZstd.find_decompressed_size(buffer1) == 5
@test CodecZstd.find_decompressed_size(buffer2) == LibZstdStatic.ZSTD_findDecompressedSize(buffer2)

iob = IOBuffer()
Expand All @@ -34,9 +35,9 @@ These tests use the static-only API to test `find_decompressed_size`
v = take!(iob)
@test CodecZstd.find_decompressed_size(v) == LibZstdStatic.ZSTD_findDecompressedSize(v)

codec = ZstdCompressor()
buffer3 = transcode(codec, "Hello")
buffer4 = transcode(codec, "World!")
codec = ZstdCompressor
buffer3 = transcode(codec, b"Hello")
buffer4 = transcode(codec, b"World!")
@test CodecZstd.find_decompressed_size(buffer3) == LibZstdStatic.ZSTD_findDecompressedSize(buffer3)
@test CodecZstd.find_decompressed_size(buffer4) == LibZstdStatic.ZSTD_findDecompressedSize(buffer4)

Expand All @@ -60,4 +61,16 @@ These tests use the static-only API to test `find_decompressed_size`
@test CodecZstd.find_decompressed_size(pointer(v), length(buffer1)) == LibZstdStatic.ZSTD_findDecompressedSize(v, length(buffer1))
end
@test CodecZstd.find_decompressed_size(v) == LibZstdStatic.ZSTD_findDecompressedSize(v)

@testset "skippable frames" begin
skippable_frame = create_skippable_frame(b"\r\0\0\0")
@test CodecZstd.find_decompressed_size(skippable_frame) == LibZstdStatic.ZSTD_findDecompressedSize(skippable_frame)
@test CodecZstd.find_decompressed_size(skippable_frame) == 0
for d in 0:2
v = vcat(circshift([buffer1, skippable_frame, buffer2], d)...)
@test CodecZstd.find_decompressed_size(v) == LibZstdStatic.ZSTD_findDecompressedSize(v)
@test CodecZstd.find_decompressed_size(v) == 11
end
end

end
13 changes: 13 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""
create_skippable_frame(user_data::AbstractVector{UInt8}, magic_number::UInt32=0x184D2A50)::Vector{UInt8}
Return a skippable frame containing `user_data`.
"""
function create_skippable_frame(user_data::AbstractVector{UInt8}, magic_number::UInt32=0x184D2A50)
@assert magic_number 0x184D2A50:0x184D2A5F
UInt8[
reinterpret(UInt8, [htol(magic_number)]);
reinterpret(UInt8, [htol(UInt32(length(user_data)))]);
user_data;
]
end

0 comments on commit cef9aaa

Please sign in to comment.