Skip to content

Commit

Permalink
directly use _filt_iir!
Browse files Browse the repository at this point in the history
avoids unnecessary allocations esp. if x has >1 column
  • Loading branch information
wheeheee committed Oct 27, 2024
1 parent a440fa8 commit ae6be31
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 15 deletions.
30 changes: 18 additions & 12 deletions src/Filters/filt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -263,18 +263,19 @@ end

# Zero phase digital filtering by processing data in forward and reverse direction
function iir_filtfilt(b::AbstractVector, a::AbstractVector, x::AbstractArray)
zi = filt_stepstate(b, a)
pad_length = min(3 * (max(length(a), length(b)) - 1), size(x, 1) - 1)
t = Base.promote_eltype(b, a, x)
zi, bn, an = filt_stepstate(b, a)
t = Base.promote_eltype(bn, an, x)
zitmp = similar(zi, t)
extrapolated = Vector{t}(undef, size(x, 1) + 2 * pad_length)
out = similar(x, t)

for i = 1:Base.trailingsize(x, 2)
istart = 1 + (i - 1) * size(x, 1)
extrapolate_signal!(extrapolated, 1, x, istart, size(x, 1), pad_length)
reverse!(filt!(extrapolated, b, a, extrapolated, mul!(zitmp, zi, extrapolated[1])))
filt!(extrapolated, b, a, extrapolated, mul!(zitmp, zi, extrapolated[1]))
_filt_iir!(extrapolated, bn, an, extrapolated, mul!(zitmp, zi, extrapolated[1]), 1)
reverse!(extrapolated)
_filt_iir!(extrapolated, bn, an, extrapolated, mul!(zitmp, zi, extrapolated[1]), 1)
for j = 1:size(x, 1)
out[j, i] = extrapolated[end-pad_length+1-j]
end
Expand All @@ -291,7 +292,7 @@ coefficients `coef`. The initial state of the filter is computed so
that its response to a step function is steady state. Before
filtering, the data is extrapolated at both ends with an
odd-symmetric extension of length
`3*(max(length(b), length(a))-1)`.
`min(3*(max(length(b), length(a))-1), size(x, 1) - 1)`
Because `filtfilt` applies the given filter twice, the effective
filter order is twice the order of `coef`. The resulting signal has
Expand Down Expand Up @@ -367,33 +368,38 @@ filtfilt(f::PolynomialRatio{:z}, x) = filtfilt(coefb(f), coefa(f), x)
## Initial filter state

# Compute an initial state for filt with coefficients (b,a) such that its
# response to a step function is steady state.
function filt_stepstate(b::Union{AbstractVector{T}, T}, a::Union{AbstractVector{T}, T}) where T<:Number
# response to a step function is steady state. Also returns padded (b, a).
function filt_stepstate(b::AbstractVector{V}, a::AbstractVector{V}) where V<:Number
T = typeof(one(V) / one(V))
scale_factor = a[1]
if !isone(scale_factor)
a = a ./ scale_factor
b = b ./ scale_factor
elseif T !== V
a = convert.(T, a)
b = convert.(T, b)
end

bs = length(b)
as = length(a)
sz = max(bs, as)
sz > 0 || throw(ArgumentError("a and b must have at least one element each"))
sz == 1 && return T[]

# Pad the coefficients with zeros if needed
bs<sz && (b = copyto!(zeros(eltype(b), sz), b))
as<sz && (a = copyto!(zeros(eltype(a), sz), a))
bs < sz && (b = copyto!(zeros(T, sz), b))
as < sz && (a = copyto!(zeros(T, sz), a))
sz == 1 && return (T[], b, a)

# construct the companion matrix A and vector B:
A = [-a[2:end] Matrix{T}(I, sz-1, sz-2)]
B = @views @. muladd(a[2:end], -b[1], b[2:end])
# Solve si = A*si + B
# (I - A)*si = B
((I - A) \ B) .*= scale_factor
si = (((I - A) \ B) .*= scale_factor)
return (si, b, a)
end

filt_stepstate(b::Union{AbstractVector{T}, T}, a::Union{AbstractVector{V}, V}) where {T<:Number,V<:Number} =
filt_stepstate(b::AbstractVector{<:Number}, a::AbstractVector{<:Number}) =
filt_stepstate(promote(b, a)...)

function filt_stepstate(f::SecondOrderSections{:z,T}) where T
Expand Down
6 changes: 3 additions & 3 deletions test/filt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ end
b = [ 0.00327922, 0.01639608, 0.03279216, 0.03279216, 0.01639608, 0.00327922]
a = [ 1. , -2.47441617, 2.81100631, -1.70377224, 0.54443269, -0.07231567]

@test (zi_python, DSP.Filters.filt_stepstate(b, a), atol=1e-7)
@test (zi_python, DSP.Filters.filt_stepstate(b, a)[1], atol=1e-7)

##############
#
Expand All @@ -99,7 +99,7 @@ end
b = [0.222, 0.43, 0.712]
a = [1, 0.33, 0.22]

@test zi_matlab DSP.Filters.filt_stepstate(b, a)
@test zi_matlab DSP.Filters.filt_stepstate(b, a)[1]


##############
Expand All @@ -118,7 +118,7 @@ end
b = [ 0.00327922, 0.01639608, 0.03279216, 0.03279216, 0.01639608, 0.00327922]
a = [ 1.1 , -2.47441617, 2.81100631, -1.70377224, 0.54443269, -0.07231567]

@test (zi_python, DSP.Filters.filt_stepstate(b, a), atol=1e-7)
@test (zi_python, DSP.Filters.filt_stepstate(b, a)[1], atol=1e-7)
end


Expand Down

0 comments on commit ae6be31

Please sign in to comment.