From 1f1f075a99fca859dac6b4d77a228c016eb2ce07 Mon Sep 17 00:00:00 2001
From: Anshul Singhvi
Date: Wed, 23 Oct 2024 00:07:51 -0700
Subject: [PATCH] Implement collect_similar like collect for DiskGenerators (#198)

* Implement `collect_similar` like `collect` for DiskGenerators

* Add a test
---
 src/generator.jl | 25 +++++++++++++++++++++++++
 test/runtests.jl | 14 ++++++++++++++
 2 files changed, 39 insertions(+)

diff --git a/src/generator.jl b/src/generator.jl
index be47e62..8619deb 100644
--- a/src/generator.jl
+++ b/src/generator.jl
@@ -47,6 +47,31 @@ function Base.collect(itr::DiskGenerator{<:AbstractArray{<:Any,N}}) where {N}
     return dest
 end
 
+# Warning: this is not public API!
+# Companion to the `Base.collect` method above; the destination is allocated with
+# `similar(A, ...)`, so its container type is derived from `A`.
+function Base.collect_similar(A::AbstractArray, itr::DiskGenerator{<:AbstractArray{<:Any,N}}) where {N}
+    y = iterate(itr)
+    shp = axes(itr.iter)
+    if y === nothing
+        # Empty generator: there is no first value to take the eltype from.
+        return similar(A, Base.@default_eltype(itr), shp)
+    end
+    # The eltype is fixed by the first value; later values are not widened.
+    dest = similar(A, typeof(first(y)), shp)
+    # Walk the generator in lockstep with the indices of the wrapped array.
+    for I in eachindex(itr.iter)
+        if y === nothing # Mainly to keep JET clean
+            error(
+                "Should not be reached: iterator is shorter than its `eachindex` iterator"
+            )
+        end
+        dest[I] = first(y)
+        y = iterate(itr, last(y))
+    end
+    return dest
+end
+
 macro implement_generator(t)
     t = esc(t)
     quote
diff --git a/test/runtests.jl b/test/runtests.jl
index db8dd1c..5c9a487 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -953,3 +953,17 @@ end
     @test getindex_count(A) == 0
 end
 
+@testset "Map over indices correctly" begin
+    # This is a regression test for issue #144
+    # `map` should always work over the correct indices,
+    # especially since we overload generators to `DiskGenerator`.
+
+    data = [i + j for i in 1:200, j in 1:100]
+    da = AccessCountDiskArray(data; chunksize=(10, 10))
+    @test map(identity, da) == data
+    @test all(map(identity, da) .== data)
+
+    # Checks result types only: `@inferred T f()` passes whenever the result `isa T`.
+    @inferred Matrix{Int} map(identity, da)
+    @inferred Matrix{Float64} map(x -> x * 5.0, da)
+end