diff --git a/src/extras/spectral_normalization.jl b/src/extras/spectral_normalization.jl index e1acab7..ea93f72 100644 --- a/src/extras/spectral_normalization.jl +++ b/src/extras/spectral_normalization.jl @@ -62,7 +62,7 @@ struct ConvSN{N,M,F,A,V, I<:Int, VV<:AbstractArray} u::VV # Left vector for power iteration end -function ConvSN(w::AbstractArray{T,N}, b::Union{Flux.Zeros, AbstractVector{T}}, σ = identity; +function ConvSN(w::AbstractArray{T,N}, b, σ = identity; stride = 1, pad = 0, dilation = 1, n_iterations = 1) where {T,N} stride = Flux.expand(Val(N-2), stride) dilation = Flux.expand(Val(N-2), dilation) @@ -73,7 +73,7 @@ end function ConvSN(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = Flux.glorot_uniform, stride = 1, pad = 0, dilation = 1, - weight = Flux.convfilter(k, ch, init = init), bias = Flux.zeros(ch[2]), n_iterations = 1) where N + weight = Flux.convfilter(k, ch, init = init), bias = zeros(ch[2]), n_iterations = 1) where N ConvSN(weight, bias, σ, stride = stride, pad = pad, dilation = dilation, n_iterations = n_iterations) end