diff --git a/src/impl/batched_mul.jl b/src/impl/batched_mul.jl index 911e7927..ce68fac1 100644 --- a/src/impl/batched_mul.jl +++ b/src/impl/batched_mul.jl @@ -54,14 +54,15 @@ function batched_matmul!(z::AbstractArray{zT, 3}, ::LoopedArrayOp, return end -function batched_matmul_cpu!(z::AbstractArray{zT, 3}, - x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {zT, xT, yT} +function batched_matmul_cpu!( + z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}, + α::Number=true, β::Number=false) where {zT, xT, yT} if can_loopvec_args(batchview(z, 1), batchview(x, 1), batchview(y, 1)) && !unsafe_known(explicit_blas_loaded()) - batched_matmul_loopvec_impl!(z, x, y) + batched_matmul_loopvec_impl!(z, x, y, α, β) return end - NNlib.batched_mul!(z, x, y) + NNlib.batched_mul!(z, x, y, α, β) return end @@ -120,7 +121,7 @@ end # This is type-piracy but needed to fix a blocking issue. TODO: upstream to NNlib # Enzyme causes a "active variables passed by value to jl_new_task are not yet supported" # warning without this patch. -for func in (NNlib.batched_mul!, batched_matmul_cpu!) +for func in (NNlib.batched_mul!, batched_matmul_loopvec_impl!) @eval begin function EnzymeRules.augmented_primal( cfg::EnzymeRules.RevConfigWidth, ::EnzymeCore.Const{typeof($(func))}, diff --git a/src/utils.jl b/src/utils.jl index c96e4611..eaa60f08 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -18,6 +18,9 @@ const KA = KernelAbstractions is_extension_loaded(::Val) = False() +CRC.@non_differentiable is_extension_loaded(::Any...) +EnzymeRules.inactive_noinl(::typeof(is_extension_loaded), ::Any...) = nothing + # Simple Operations -- no rrules needed ofeltype_array(::Type{T}, x::AbstractArray{T}) where {T} = x function ofeltype_array( @@ -328,4 +331,8 @@ end @inline can_loopvec_args_check(::False, args...) = false +CRC.@non_differentiable can_loopvec_args_check(::Any...) + +EnzymeRules.inactive_noinl(::typeof(can_loopvec_args_check), ::Any...) 
= nothing + end diff --git a/test/common_ops/activation_tests.jl index 2045f20f..e2b80e71 100644 --- a/test/common_ops/activation_tests.jl +++ b/test/common_ops/activation_tests.jl @@ -36,7 +36,7 @@ @jet apply_act_fast2(f, x) @test @inferred(Zygote.gradient(apply_act, f, x)) isa Any - if f !== lisht || (f === lisht && T == Float32 && !ongpu) + if f !== lisht @test @inferred(Zygote.gradient(apply_act_fast, f, x)) isa Any end @test @inferred(Zygote.gradient(apply_act_fast2, f, x)) isa Any diff --git a/test/common_ops/bias_act_tests.jl index 1429c9b2..a7499654 100644 --- a/test/common_ops/bias_act_tests.jl +++ b/test/common_ops/bias_act_tests.jl @@ -44,12 +44,11 @@ @jet bias_act_loss2(act, x, b) @jet bias_act_loss3(act, x, b) - if (act !== lisht || (act === lisht && T == Float32 && !ongpu)) && T != Float16 - @test @inferred(Zygote.gradient(bias_act_loss2, act, x, b)) isa Any - @test @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any - elseif T != Float16 - @test_broken @inferred(Zygote.gradient(bias_act_loss2, act, x, b)) isa Any - @test_broken @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any + if act !== lisht + @test @inferred(Zygote.gradient(bias_act_loss2, act, x, b)) isa Any broken=(T == Float16) + @test @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any broken=(T == Float16) + end @test_gradients(__Fix1(bias_act_loss1, act), x, b; atol, rtol,