Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Constrain type in to_vec(::AbstractArray/Vector) #156

Merged
merged 13 commits into from
Apr 28, 2021
8 changes: 4 additions & 4 deletions src/to_vec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ function to_vec(x::T) where {T}
v, vals_from_vec = to_vec(vals)
function structtype_from_vec(v::Vector{<:Real})
val_vecs = vals_from_vec(v)
vals = map((b, v) -> b(v), backs, val_vecs)
return T(vals...)
values = map((b, v) -> b(v), backs, val_vecs)
return T(values...)
end
return v, structtype_from_vec
end

function to_vec(x::StridedVector)
function to_vec(x::Union{SubArray, Base.ReshapedArray, StridedVector})
x_vecs_and_backs = map(to_vec, x)
x_vecs, backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs)
function Vector_from_vec(x_vec)
Expand All @@ -53,7 +53,7 @@ function to_vec(x::StridedVector)
return x_vec, Vector_from_vec
end

function to_vec(x::Union{SubArray, StridedArray})
function to_vec(x::Union{SubArray, Base.ReshapedArray, StridedArray})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe the parent array of SubArray should be restricted to e.g. StridedArray or something?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want e.g. a

julia> sa = @view randn(T, 10)[1:4]
julia> typeof(sa)
SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}

to dispatch on the current SubArray implementation. Unfortunately StridedArray is a Union of several more specific SubArrays (not just in the parent argument)

So we have two options:

  1. Narrow the dispatch of the current SubArray (very verbose*)
  2. Loosen the dispatch for StridedArray to Union{StridedArray, SubArray} so that the current SubArray becomes more specific.

And similarly for the ReshapedArray.

I don't particularly like either option but it's probably better than the status quo.

*Number 1) looks like this, I want to nope out of that 😂

StridedSubArray = Union{
SubArray{T, N, A, I, true} where {T, N, A<:DenseArray, I<:Union{Tuple{Vararg{Real, N} where N}, Tuple{AbstractUnitRange, Vararg{Any, N} where N}}}, DenseArray}, MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}},
SubArray{T, 1, A, I, L} where {A<:Union{Base.ReinterpretArray{T, N, S, A, IsReshaped} where {T, N, A<:Union{SubArray{T, N, A, I, true} where {T, N, A<:DenseArray, I<:Union{Tuple{Vararg{Real, N} where N}, Tuple{AbstractUnitRange, Vararg{Any, N} where N}}}, DenseArray}, IsReshaped, S},
SubArray{T, N, A, I, true} where {T, N, A<:DenseArray, I<:Union{Tuple{Vararg{Real, N} where N}, Tuple{AbstractUnitRange, Vararg{Any, N} where N}}}, DenseArray}, MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}}, DenseArray}, I<:Tuple{Vararg{Union{Int64, AbstractRange{Int64}, Base.AbstractCartesianIndex, Base.ReshapedArray{T, N, A, Tuple{}} where {T, N, A<:AbstractUnitRange}}, N} where N}, L}} where T)
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@willtebbutt how do you feel about that?

Copy link
Member

@willtebbutt willtebbutt Apr 23, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The thing I'm concerned about is if someone attempts to to_vec something like

view(Diagonal(randn(5)), 1:5, 1:3)

*Number 1) looks like this, I want to nope out of that 😂

Lol yeah, that's really not fun.

Is there a reason that we can't go with something like

Union{StridedArray, SubArray{T, D, <:StridedArray} where {T, D}}

and just recurse into ReshapedArrays?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The dispatch for the Diagonal example is the same (the dedicated SubArray method) whether we constrain the parent of the SubArray to be strided or not.

The example breaks, but because SparseArray.SparseMatrixCSC does not have a to_vec defined.

Copy link
Member Author

@mzgubic mzgubic Apr 26, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The status is:

  • All tests pass with the current PR
  • Removing the to_vec(Base.SubArray) breaks some tests
  • test_to_vec(view(Diagonal(randn(5)), 1:5, 1:3); check_inferred=false), (the thing you suggested should be checked) breaks on master (StackOverflowError), on the current PR (SparseMatrixCSC error), and on the current PR with SubArray removed from the Union (SparseMatrixCSC error).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh I see.

I've run it further locally, and found that the following seems to resolve everything for me:

  1. restrict the array methods to DenseVector and DenseArray, from Union{StridedArray...}
  2. add a specific method for SubArray:
function to_vec(x::SubArray)
    x_vec, from_vec = to_vec(x.parent)
    SubArray_from_vec(x_vec) = view(from_vec(x_vec), x.indices...)
    return x_vec, SubArray_from_vec
end

if you don't have this, it thinks that the integer fields should also be to_veced.
3. comment out the SubArray-specific method

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yeah that's much better. Are we happy with that?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm happy with it, but I feel that we should get an additional review. Maybe @oxinabox or @sethaxen ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually with this one ChainRules test fails, and the reason is that the current implementation returns the whole parent vector

julia> a = rand(7);

julia> b = @view a[2:5];

julia> v, back = to_vec(b);

julia> v
7-element Vector{Float64}:
 0.9586867155061674
 0.8025760289977295
 0.9784238101274141
 0.16440182965461236
 0.17784328126321847
 0.01704930118685577
 0.8408143603848348

Keeping the old version below works just fine. No issues now StridedArray has been replaced by DenseArray (where we used to have AbstractArray)

to_vec(x::SubArray) = to_vec(copy(x))

x_vec, from_vec = to_vec(vec(x))

function Array_from_vec(x_vec)
Expand Down
17 changes: 7 additions & 10 deletions test/to_vec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@ end
Base.:(==)(x::DummyType, y::DummyType) = x.X == y.X
Base.length(x::DummyType) = size(x.X, 1)

# A dummy FillVector. This is a type for which the fallback implementation of
# `to_vec` should fail loudly.
# A dummy FillVector
struct FillVector <: AbstractVector{Float64}
x::Float64
len::Int
end

Base.size(x::FillVector) = (x.len,)
Base.getindex(x::FillVector, n::Int) = x.x

# For testing Composite{ThreeFields}
struct ThreeFields
a
Expand All @@ -32,9 +34,6 @@ struct Nested
y::Singleton
end

Base.size(x::FillVector) = (x.len,)
Base.getindex(x::FillVector, n::Int) = x.x

function test_to_vec(x::T; check_inferred = true) where {T}
check_inferred && @inferred to_vec(x)
x_vec, back = to_vec(x)
Expand Down Expand Up @@ -67,8 +66,8 @@ end
test_to_vec(UpperTriangular(randn(T, 13, 13)))
test_to_vec(Diagonal(randn(T, 7)))
test_to_vec(DummyType(randn(T, 2, 9)))
test_to_vec(SVector{2, T}(1.0, 2.0))
test_to_vec(SMatrix{2, 2, T}(1.0, 2.0, 3.0, 4.0))
test_to_vec(SVector{2, T}(1.0, 2.0); check_inferred = false)
test_to_vec(SMatrix{2, 2, T}(1.0, 2.0, 3.0, 4.0); check_inferred = false)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
test_to_vec(SVector{2, T}(1.0, 2.0); check_inferred = false)
test_to_vec(SMatrix{2, 2, T}(1.0, 2.0, 3.0, 4.0); check_inferred = false)
test_to_vec(SVector{2, T}(1.0, 2.0); check_inferred=false)
test_to_vec(SMatrix{2, 2, T}(1.0, 2.0, 3.0, 4.0); check_inferred=false)

Bluetyle spacing in kwargs

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed in other places as well now

test_to_vec(@view randn(T, 10)[1:4]) # SubArray -- Vector
test_to_vec(@view randn(T, 10, 2)[1:4, :]) # SubArray -- Matrix
test_to_vec(Base.ReshapedArray(rand(T, 3, 3), (9,), ()))
Expand Down Expand Up @@ -173,9 +172,7 @@ end
end

@testset "FillVector" begin
x = FillVector(5.0, 10)
x_vec, from_vec = to_vec(x)
@test_throws MethodError from_vec(randn(10))
test_to_vec(FillVector(5.0, 10); check_inferred=false)
end

@testset "fallback" begin
Expand Down