Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
fix: enzyme segfault bypass
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 18, 2024
1 parent 0b19476 commit d2f76dd
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions src/impl/batched_mul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ function batched_matmul_cpu!(z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3},
batched_matmul_loopvec_impl!(z, x, y)
return
end
NNlib.batched_mul!(z, x, y)
# Avoid an Enzyme segfault https://github.com/EnzymeAD/Enzyme.jl/issues/1983
fallback_batched_matmul!(z, LoopedArrayOp(), x, y)
# NNlib.batched_mul!(z, x, y) # XXX: restore once the enzyme segfault is fixed
return
end

Expand All @@ -78,13 +80,18 @@ end
function fallback_batched_matmul!(
z::AbstractArray{zT, 3}, dev, x::AbstractArray{xT, 3},
y::AbstractArray{yT, 3}) where {zT, xT, yT}
@warn "Using fallback Batched Matrix Multiply routine for $(dev) with A: size = \
$(size(x)) eltype = $(xT) and B: size = $(size(y)) eltype = $(yT). This may be \
slow." maxlog=1
# XXX: bring back once the enzyme segfault is fixed
# @warn "Using fallback Batched Matrix Multiply routine for $(dev) with A: size = \
# $(size(x)) eltype = $(xT) and B: size = $(size(y)) eltype = $(yT). This may be \
# slow." maxlog=1

if (size(x, 3) != size(y, 3) && size(x, 3) != 1 && size(y, 3) != 1) ||
(size(x, 2) != size(y, 1))
throw(DimensionMismatch(lazy"size(x) = $(size(x)), size(y) = $(size(y)) inconsistent for batched_matmul."))
end

old_threads = maybe_reduce_BLAS_threads(z)

if size(x, 3) == size(y, 3)
Threads.@threads for L in axes(z, 3)
mul!(batchview(z, L), batchview(x, L), batchview(y, L))
Expand All @@ -98,6 +105,10 @@ function fallback_batched_matmul!(
mul!(batchview(z, L), batchview(x, L), batchview(y, 1))
end
end

reset_BLAS_threads(old_threads)

return
end

function CRC.rrule(::typeof(batched_matmul), x::AbstractArray{xT, 3},
Expand Down

0 comments on commit d2f76dd

Please sign in to comment.