Skip to content

Commit

Permalink
Split up tests. Test Float16 and Complex{Float32}. Test overridin…
Browse files Browse the repository at this point in the history
…g `mul!`
  • Loading branch information
RomeoV committed Jul 5, 2024
1 parent e752f9a commit 32bf475
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 45 deletions.
54 changes: 9 additions & 45 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,62 +4,26 @@ import SparseArrays: mul!, sprand
import Profile

# Helper function to run common test logic
function run_common_tests(method!, buf::AbstractMatrix{T}, lhs, rhs, α, β, baseline) where {T <: Real}
function run_common_tests(method!, buf::AbstractMatrix{T}, lhs, rhs, α, β, baseline) where {T <: Number}
method!(buf, lhs, rhs, α, β)
@test buf baseline rtol=sqrt(eps(T))
@test buf baseline rtol=sqrt(eps(real(T)))
@test !any(isnan, buf)

# Test with negative α
method!(buf, lhs, rhs, -α, β)
method!(buf, lhs, rhs, α, β)
@test buf baseline rtol=sqrt(eps(T))
@test buf baseline rtol=sqrt(eps(real(T)))
@test !any(isnan, buf)
end

@testset "ThreadedDenseSparseMul Tests" begin
@test ThreadedDenseSparseMul.get_num_threads() == Threads.nthreads()
@testset "Dense-Sparse Multiplication" begin
@testset "$T type" for T in [Float64, Float32]
@testset "$method! implementation" for method! in [fastdensesparsemul!, fastdensesparsemul_threaded!]
@testset "Trial $trial" for trial in 1:10
lhs = rand(T, 50, 100)
rhs = sprand(T, 100, 1_000, 0.1)
baseline = lhs * rhs

buf = similar(baseline) .* false # fill buffer with zeros. Carefull with NaNs, see https://discourse.julialang.org/t/occasionally-nans-when-using-similar/48224/12

# Test basic multiplication
run_common_tests(method!, buf, lhs, rhs, 1, 0, baseline)

# Test @view interface and β ≠ 0
inds = rand(axes(lhs, 1), size(lhs, 1) ÷ 3)
baseline[inds, :] .+= 2.5 * @view(lhs[inds, :]) * rhs

run_common_tests(method!, @view(buf[inds, :]), @view(lhs[inds, :]), rhs, 2.5, 1, @view(baseline[inds, :]))
end
end
end
end

@testset "Outer Product Multiplication" begin
@testset "$T type" for T in [Float64, Float32]
@testset "$method! implementation" for method! in [fastdensesparsemul_outer!, fastdensesparsemul_outer_threaded!]
@testset "Trial $trial" for trial in 1:10
lhs = rand(T, 50, 100)
rhs = sprand(T, 100, 1_000, 0.1)
k = rand(axes(rhs, 1))

baseline = lhs[:, k:k] * rhs[k:k, :]
buf = similar(baseline) .* false

# Test basic outer product multiplication
run_common_tests(method!, buf, @view(lhs[:, k]), rhs[k, :], 1, 0, baseline)

# Test with β ≠ 0
baseline .+= 2.5 * lhs[:, (k+1):(k+1)] * rhs[(k+1):(k+1), :]
run_common_tests(method!, buf, lhs[:, k+1], rhs[k+1, :], 2.5, 1, baseline)
end
end
@testset "Override mul! ?" for override_mul! in [false, true]
override_mul! && ThreadedDenseSparseMul.override_mul!()
@testset "nthreads" for nthreads in [1, Threads.nthreads()]
ThreadedDenseSparseMul.set_num_threads(nthreads)
include("test_densesparsemul.jl")
include("test_densesparseouter.jl")
end
end
end
22 changes: 22 additions & 0 deletions test/test_densesparsemul.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
@testset "Dense-Sparse Multiplication" begin
@testset "$T type" for T in [Float64, Float32, Float16, Complex{Float32}]
@testset "$method! implementation" for method! in [fastdensesparsemul!, fastdensesparsemul_threaded!]
@testset "Trial $trial" for trial in 1:10
lhs = rand(T, 50, 100)
rhs = sprand(T, 100, 1_000, 0.1)
baseline = lhs * rhs

buf = similar(baseline) .* false # fill buffer with zeros. Carefull with NaNs, see https://discourse.julialang.org/t/occasionally-nans-when-using-similar/48224/12

# Test basic multiplication
run_common_tests(method!, buf, lhs, rhs, 1, 0, baseline)

# Test @view interface and β ≠ 0
inds = rand(axes(lhs, 1), size(lhs, 1) ÷ 3)
baseline[inds, :] .+= 2.5 * @view(lhs[inds, :]) * rhs

run_common_tests(method!, @view(buf[inds, :]), @view(lhs[inds, :]), rhs, 2.5, 1, @view(baseline[inds, :]))
end
end
end
end
21 changes: 21 additions & 0 deletions test/test_densesparseouter.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
@testset "Outer Product Multiplication" begin
@testset "$T type" for T in [Float64, Float32]
@testset "$method! implementation" for method! in [fastdensesparsemul_outer!, fastdensesparsemul_outer_threaded!]
@testset "Trial $trial" for trial in 1:10
lhs = rand(T, 50, 100)
rhs = sprand(T, 100, 1_000, 0.1)
k = rand(axes(rhs, 1))

baseline = lhs[:, k:k] * rhs[k:k, :]
buf = similar(baseline) .* false

# Test basic outer product multiplication
run_common_tests(method!, buf, @view(lhs[:, k]), rhs[k, :], 1, 0, baseline)

# Test with β ≠ 0
baseline .+= 2.5 * lhs[:, (k+1):(k+1)] * rhs[(k+1):(k+1), :]
run_common_tests(method!, buf, lhs[:, k+1], rhs[k+1, :], 2.5, 1, baseline)
end
end
end
end

0 comments on commit 32bf475

Please sign in to comment.