Skip to content

Commit

Permalink
Merge pull request #18 from andrewrosemberg/ar/master
Browse files Browse the repository at this point in the history
Fix stuff
  • Loading branch information
ancorso authored Nov 24, 2024
2 parents 2ea63c3 + 9184e7e commit 9c88f30
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/extras/spectral_normalization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ struct ConvSN{N,M,F,A,V, I<:Int, VV<:AbstractArray}
u::VV # Left vector for power iteration
end

function ConvSN(w::AbstractArray{T,N}, b::Union{Flux.Zeros, AbstractVector{T}}, σ = identity;
function ConvSN(w::AbstractArray{T,N}, b, σ = identity;
stride = 1, pad = 0, dilation = 1, n_iterations = 1) where {T,N}
stride = Flux.expand(Val(N-2), stride)
dilation = Flux.expand(Val(N-2), dilation)
Expand All @@ -73,7 +73,7 @@ end

function ConvSN(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = Flux.glorot_uniform, stride = 1, pad = 0, dilation = 1,
weight = Flux.convfilter(k, ch, init = init), bias = Flux.zeros(ch[2]), n_iterations = 1) where N
weight = Flux.convfilter(k, ch, init = init), bias = zeros(ch[2]), n_iterations = 1) where N
ConvSN(weight, bias, σ, stride = stride, pad = pad, dilation = dilation, n_iterations = n_iterations)
end

Expand Down

0 comments on commit 9c88f30

Please sign in to comment.