Add more filters (#154)
* Document the Filter interface

* Move filters to a folder

Same rationale as the other changes :D - just for cleanliness and clarity.

* Factor out variable-length filters to a new file

* Add docstrings to filter API functions

* Add a Fletcher32 filter and test

* Re-add the dictionary entries for the vlen filters

* Semi-working fixed scale offset filter

* Add FixedScaleOffset tests

* Add shuffle filter (buggy in the last few bytes, indexing issues)

* WIP quantize filter

* ShuffleFilter working and tested

* Semi-working quantize filter

* Format tests better

* Complete interface and test quantize

* Uncomment the FixedScaleOffset tests

* Fix `getfilter` syntax

* Add delta filter

* Adapt for Kerchunk playing fast and loose with the spec

- Kerchunk often encodes the compressor as the last filter, so if the compressor is null, we check whether it is hiding at the end of the filters array.
- Similarly, the dtype is often unknown in this case, or the transform is not encoded correctly, so we ensure that the datatypes of `data` and `a2` remain the same by reinterpreting.

* Fix the delta and quantize JSON.lower

* Change the tests to be more sensible/Julian and avoid truncation errors

* Fix the FixedScaleOffset filter materializer

* Fix decoding for fill values to use `reinterpret` on unsigned -> integer

* If `getfilter` fails, show the filter name and then throw an error

* Apply reinterpret before multiplication in fixed-scale-offset filter

* Only reinterpret negative integers when decoding fill values to unsigned

* Revert "Only reinterpret negative integers when decoding fill values to unsigned"

This reverts commit 24a68e6.

* Let Fletcher32 operate on n-dimensional arrays

not just the vectors it was previously constrained to

* Fix FixedScaleOffset in many ways

- Never use `reinterpret`
- Use array comprehensions to support 0-dimensional arrays correctly; the performance impact is negligible based on testing
- Only round if the target type is an integer; if it's a float, leave the value as-is

* Add filter tests in Python

* Fix filter `astype` and `id` to conform to Python names

* Remove the encoding validity check for quantize - it's pointless
asinghvi17 authored Dec 6, 2024 · 1 parent 2ae3c2a · commit b52be51
Showing 12 changed files with 640 additions and 42 deletions.
95 changes: 95 additions & 0 deletions src/Filters/Filters.jl
@@ -0,0 +1,95 @@
import JSON

"""
abstract type Filter{T,TENC}
The supertype for all Zarr filters.
## Interface
All subtypes MUST implement the following methods:
- [`zencode(ain, filter::Filter)`](@ref zencode): Encodes data `ain` using the filter, and returns a vector of bytes.
- [`zdecode(ain, filter::Filter)`](@ref zdecode): Decodes data `ain`, a vector of bytes, using the filter, and returns the original data.
- [`JSON.lower`](@ref): Returns a JSON-serializable dictionary representing the filter, according to the Zarr specification.
- [`getfilter(::Type{<: Filter}, filterdict)`](@ref getfilter): Returns the filter type read from a given filter dictionary.
If the filter has type parameters, it MUST also implement:
- [`sourcetype(::Filter)::T`](@ref sourcetype): equivalent to `dtype` in the Python Zarr implementation.
- [`desttype(::Filter)::T`](@ref desttype): equivalent to `atype` in the Python Zarr implementation.
Finally, an entry MUST be added to the `filterdict` dictionary for each filter type.
This must also follow the Zarr specification's name for that filter. The name of the filter
is the key, and the value is the filter type (e.g. `VLenUInt8Filter` or `Fletcher32Filter`).
Subtypes include: [`VLenArrayFilter`](@ref), [`VLenUTF8Filter`](@ref), [`Fletcher32Filter`](@ref).
"""
abstract type Filter{T,TENC} end

"""
zencode(ain, filter::Filter)
Encodes data `ain` using the filter, and returns a vector of bytes.
"""
function zencode end

"""
zdecode(ain, filter::Filter)
Decodes data `ain`, a vector of bytes, using the filter, and returns the original data.
"""
function zdecode end

"""
getfilter(::Type{<: Filter}, filterdict)
Returns the filter type read from a given specification dictionary, which must follow the Zarr specification.
"""
function getfilter end

"""
sourcetype(::Filter)::T
Returns the source type of the filter.
"""
function sourcetype end

"""
desttype(::Filter)::T
Returns the destination type of the filter.
"""
function desttype end

filterdict = Dict{String,Type{<:Filter}}()

function getfilters(d::Dict)
    if !haskey(d, "filters")
        return nothing
    else
        if d["filters"] === nothing || isempty(d["filters"])
            return nothing
        end
        f = map(d["filters"]) do f
            try
                getfilter(filterdict[f["id"]], f)
            catch e
                @show f
                rethrow(e)
            end
        end
        return (f...,)
    end
end
sourcetype(::Filter{T}) where T = T
desttype(::Filter{<:Any,T}) where T = T

zencode(ain, ::Nothing) = ain

include("vlenfilters.jl")
include("fletcher32.jl")
include("fixedscaleoffset.jl")
include("shuffle.jl")
include("quantize.jl")
include("delta.jl")
45 changes: 45 additions & 0 deletions src/Filters/delta.jl
@@ -0,0 +1,45 @@
#=
# Delta compression
=#

"""
DeltaFilter(; DecodingType, [EncodingType = DecodingType])
Delta-based compression for Zarr arrays. (Delta encoding is Julia `diff`, decoding is Julia `cumsum`).
"""
struct DeltaFilter{T, TENC} <: Filter{T, TENC}
end

function DeltaFilter(; DecodingType = Float16, EncodingType = DecodingType)
    return DeltaFilter{DecodingType, EncodingType}()
end

DeltaFilter{T}() where T = DeltaFilter{T, T}()

function zencode(data::AbstractArray, filter::DeltaFilter{DecodingType, EncodingType}) where {DecodingType, EncodingType}
    arr = reinterpret(DecodingType, vec(data))

    enc = similar(arr, EncodingType)
    # perform the delta operation
    enc[begin] = arr[begin]
    enc[begin+1:end] .= diff(arr)
    return enc
end

function zdecode(data::AbstractArray, filter::DeltaFilter{DecodingType, EncodingType}) where {DecodingType, EncodingType}
    encoded = reinterpret(EncodingType, vec(data))
    decoded = DecodingType.(cumsum(encoded))
    return decoded
end

function JSON.lower(filter::DeltaFilter{T, Tenc}) where {T, Tenc}
    return Dict("id" => "delta", "dtype" => typestr(T), "astype" => typestr(Tenc))
end

function getfilter(::Type{<: DeltaFilter}, d)
    T = typestr(d["dtype"])
    Tenc = haskey(d, "astype") ? typestr(d["astype"]) : T
    return DeltaFilter{T, Tenc}()
end

filterdict["delta"] = DeltaFilter
52 changes: 52 additions & 0 deletions src/Filters/fixedscaleoffset.jl
@@ -0,0 +1,52 @@

"""
FixedScaleOffsetFilter{T,TENC}(scale, offset)
A compressor that scales and offsets the data.
!!! note
The geographic CF standards define scale/offset decoding as `x * scale + offset`,
but this filter defines it as `x / scale + offset`. Constructing a `FixedScaleOffsetFilter`
from CF data means `FixedScaleOffsetFilter(1/cf_scale_factor, cf_add_offset)`.
"""
struct FixedScaleOffsetFilter{ScaleOffsetType, T, Tenc} <: Filter{T, Tenc}
    scale::ScaleOffsetType
    offset::ScaleOffsetType
end

# Convenience constructors; the full parameter order is `{ScaleOffsetType, T, Tenc}`.
FixedScaleOffsetFilter{T}(scale::ScaleOffsetType, offset::ScaleOffsetType) where {T, ScaleOffsetType} = FixedScaleOffsetFilter{ScaleOffsetType, T, T}(scale, offset)
FixedScaleOffsetFilter(scale::ScaleOffsetType, offset::ScaleOffsetType) where {ScaleOffsetType} = FixedScaleOffsetFilter{ScaleOffsetType, ScaleOffsetType, ScaleOffsetType}(scale, offset)

function FixedScaleOffsetFilter(; scale::ScaleOffsetType, offset::ScaleOffsetType, T, Tenc = T) where ScaleOffsetType
    return FixedScaleOffsetFilter{ScaleOffsetType, T, Tenc}(scale, offset)
end

function zencode(a::AbstractArray, c::FixedScaleOffsetFilter{ScaleOffsetType, T, Tenc}) where {T, Tenc, ScaleOffsetType}
    if Tenc <: Integer
        return [round(Tenc, (x - c.offset) * c.scale) for x in a] # apply scale and offset, then round to the nearest integer
    else
        return [convert(Tenc, (x - c.offset) * c.scale) for x in a] # apply scale and offset
    end
end

function zdecode(a::AbstractArray, c::FixedScaleOffsetFilter{ScaleOffsetType, T, Tenc}) where {T, Tenc, ScaleOffsetType}
    return [convert(Base.nonmissingtype(T), (x / c.scale) + c.offset) for x in a]
end


function getfilter(::Type{<: FixedScaleOffsetFilter}, d::Dict)
    scale = d["scale"]
    offset = d["offset"]
    # Types must be converted from strings to the actual Julia types they represent.
    string_T = d["dtype"]
    string_Tenc = get(d, "astype", string_T)
    T = typestr(string_T)
    Tenc = typestr(string_Tenc)
    return FixedScaleOffsetFilter{Tenc, T, Tenc}(scale, offset)
end

function JSON.lower(c::FixedScaleOffsetFilter{ScaleOffsetType, T, Tenc}) where {ScaleOffsetType, T, Tenc}
    return Dict("id" => "fixedscaleoffset", "scale" => c.scale, "offset" => c.offset, "dtype" => typestr(T), "astype" => typestr(Tenc))
end

filterdict["fixedscaleoffset"] = FixedScaleOffsetFilter
85 changes: 85 additions & 0 deletions src/Filters/fletcher32.jl
@@ -0,0 +1,85 @@
#=
# Fletcher32 filter
This "filter" basically injects a 4-byte checksum at the end of the data, to ensure data integrity.
The implementation is based on the [numcodecs implementation here](https://github.com/zarr-developers/numcodecs/blob/79d1a8d4f9c89d3513836aba0758e0d2a2a1cfaf/numcodecs/fletcher32.pyx)
and the [original C implementation for NetCDF](https://github.com/Unidata/netcdf-c/blob/main/plugins/H5checksum.c#L109) linked therein.
=#

"""
Fletcher32Filter()
A compressor that uses the Fletcher32 checksum algorithm to compress and uncompress data.
Note that this goes from UInt8 to UInt8, and is effectively only checking
the checksum and cropping the last 4 bytes of the data during decoding.
"""
struct Fletcher32Filter <: Filter{UInt8, UInt8}
end

getfilter(::Type{<: Fletcher32Filter}, d::Dict) = Fletcher32Filter()
JSON.lower(::Fletcher32Filter) = Dict("id" => "fletcher32")
filterdict["fletcher32"] = Fletcher32Filter

function _checksum_fletcher32(data::AbstractArray{UInt8})
    len = length(data) ÷ 2 # length in 16-bit words
    sum1::UInt32 = 0
    sum2::UInt32 = 0
    data_idx = 1

    #=
    Compute the checksum for pairs of bytes.
    The magic `360` value is the largest number of sums that can be performed without overflow in UInt32.
    =#
    while len > 0
        tlen = len > 360 ? 360 : len
        len -= tlen
        while tlen > 0
            sum1 += begin # create a 16 bit word from two bytes, the first shifted into the high byte of the word
                (UInt16(data[data_idx]) << 8) | UInt16(data[data_idx + 1])
            end
            sum2 += sum1
            data_idx += 2
            tlen -= 1
            if tlen < 1
                break
            end
        end
        sum1 = (sum1 & 0xffff) + (sum1 >> 16)
        sum2 = (sum2 & 0xffff) + (sum2 >> 16)
    end

    # if the length of the data is odd, fold in the last (unpaired) byte,
    # following the reference C implementation
    if length(data) % 2 == 1
        sum1 += UInt16(data[data_idx]) << 8
        sum2 += sum1
        sum1 = (sum1 & 0xffff) + (sum1 >> 16)
        sum2 = (sum2 & 0xffff) + (sum2 >> 16)
    end
    return (sum2 << 16) | sum1
end

function zencode(data, ::Fletcher32Filter)
    bytes = reinterpret(UInt8, vec(data))
    checksum = _checksum_fletcher32(bytes)
    result = copy(bytes)
    append!(result, reinterpret(UInt8, [checksum])) # TODO: decompose this without the extra allocation of wrapping in Array
    return result
end

function zdecode(data, ::Fletcher32Filter)
    bytes = reinterpret(UInt8, data)
    checksum = _checksum_fletcher32(view(bytes, 1:length(bytes) - 4))
    stored_checksum = only(reinterpret(UInt32, view(bytes, (length(bytes) - 3):length(bytes))))
    if checksum != stored_checksum
        throw(ErrorException("""
            Checksum mismatch in Fletcher32 decoding.
            The computed value is $(checksum) and the stored value is $(stored_checksum).
            This might be a sign that the data is corrupted.
            """)) # TODO: make this a custom error type
    end
    return view(bytes, 1:length(bytes) - 4)
end
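A hedged round-trip sketch (assuming the filter is in scope): encoding appends the 4-byte checksum, and decoding verifies it before stripping it off.

```julia
data = UInt8[0x01, 0x02, 0x03, 0x04]
enc = zencode(data, Fletcher32Filter())  # 8 bytes: payload + UInt32 checksum
dec = zdecode(enc, Fletcher32Filter())   # a view of the original 4 bytes
@assert collect(dec) == data             # flipping any byte of `enc` would throw instead
```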
52 changes: 52 additions & 0 deletions src/Filters/quantize.jl
@@ -0,0 +1,52 @@
#=
# Quantize compression
=#

"""
QuantizeFilter(; digits, DecodingType, [EncodingType = DecodingType])
Quantization based compression for Zarr arrays.
"""
struct QuantizeFilter{T, TENC} <: Filter{T, TENC}
    digits::Int32
end

function QuantizeFilter(; digits = 10, T = Float16, Tenc = T)
    return QuantizeFilter{T, Tenc}(digits)
end

QuantizeFilter{T, Tenc}(; digits = 10) where {T, Tenc} = QuantizeFilter{T, Tenc}(digits)
QuantizeFilter{T}(; digits = 10) where T = QuantizeFilter{T, T}(digits)

function zencode(data::AbstractArray, filter::QuantizeFilter{DecodingType, EncodingType}) where {DecodingType, EncodingType}
    arr = reinterpret(DecodingType, vec(data))

    precision = 10.0^(-filter.digits)

    _exponent = log(10, precision) # log of `precision` in base 10
    exponent = _exponent < 0 ? floor(Int, _exponent) : ceil(Int, _exponent)

    bits = ceil(log(2, 10.0^(-exponent)))
    scale = 2.0^bits

    enc = @. convert(EncodingType, round(scale * arr) / scale)

    return enc
end

# Decoding is a no-op: quantization is lossy at encode time, and the stored values are returned as-is.
function zdecode(data::AbstractArray, filter::QuantizeFilter{DecodingType, EncodingType}) where {DecodingType, EncodingType}
    return data
end

function JSON.lower(filter::QuantizeFilter{T, Tenc}) where {T, Tenc}
    return Dict("id" => "quantize", "digits" => filter.digits, "dtype" => typestr(T), "astype" => typestr(Tenc))
end

function getfilter(::Type{<: QuantizeFilter}, d)
    T = typestr(d["dtype"])
    Tenc = haskey(d, "astype") ? typestr(d["astype"]) : T
    return QuantizeFilter{T, Tenc}(; digits = d["digits"])
end

filterdict["quantize"] = QuantizeFilter
