diff --git a/ext/AccessorsStaticArraysExt.jl b/ext/AccessorsStaticArraysExt.jl index 7f728f86..8d7e856c 100644 --- a/ext/AccessorsStaticArraysExt.jl +++ b/ext/AccessorsStaticArraysExt.jl @@ -1,16 +1,17 @@ module AccessorsStaticArraysExt -isdefined(Base, :get_extension) ? (import StaticArrays) : (import ..StaticArrays) +isdefined(Base, :get_extension) ? (using StaticArrays) : (using ..StaticArrays) using Accessors import Accessors: setindex, delete, insert -@inline setindex(a::StaticArrays.StaticArray, args...) = Base.setindex(a, args...) -@inline delete(obj::StaticArrays.SVector, l::IndexLens) = StaticArrays.deleteat(obj, only(l.indices)) -@inline insert(obj::StaticArrays.SVector, l::IndexLens, val) = StaticArrays.insert(obj, only(l.indices), val) +@inline setindex(a::StaticArray, args...) = Base.setindex(a, args...) +@inline delete(obj::StaticVector, l::IndexLens) = StaticArrays.deleteat(obj, only(l.indices)) +@inline insert(obj::StaticVector, l::IndexLens, val) = StaticArrays.insert(obj, only(l.indices), val) -Accessors.set(obj::StaticArrays.SVector, ::Type{Tuple}, val::Tuple) = StaticArrays.SVector(val) +Accessors.set(obj::StaticVector, ::Type{Tuple}, val::Tuple) = constructorof(typeof(obj))(val...) +Accessors.set(obj::Tuple, ::Type{<:StaticVector}, val::StaticVector) = Tuple(val) -Accessors.getall(obj::StaticArrays.StaticArray, ::Elements) = Tuple(obj) -Accessors.setall(obj::StaticArrays.StaticArray, ::Elements, vs::AbstractArray) = constructorof(typeof(obj))(vs...) # just for disambiguation -Accessors.setall(obj::StaticArrays.StaticArray, ::Elements, vs) = constructorof(typeof(obj))(vs...) +Accessors.getall(obj::StaticArray, ::Elements) = Tuple(obj) +Accessors.setall(obj::StaticArray, ::Elements, vs::AbstractArray) = constructorof(typeof(obj))(vs...) # just for disambiguation +Accessors.setall(obj::StaticArray, ::Elements, vs) = constructorof(typeof(obj))(vs...) end diff --git a/src/functionlenses.jl b/src/functionlenses.jl index 66b61495..1f5911b7 100644 --- a/src/functionlenses.jl +++ b/src/functionlenses.jl @@ -1,12 +1,13 @@ using LinearAlgebra: norm, normalize using Dates -set(obj, ::typeof(last), val) = @set obj[lastindex(obj)] = val +# first and last on general indexable collections set(obj, ::typeof(first), val) = @set obj[firstindex(obj)] = val -delete(obj, ::typeof(last)) = delete(obj, IndexLens((lastindex(obj),))) +set(obj, ::typeof(last), val) = @set obj[lastindex(obj)] = val delete(obj, ::typeof(first)) = delete(obj, IndexLens((firstindex(obj),))) -insert(obj, ::typeof(last), val) = insert(obj, IndexLens((lastindex(obj) + 1,)), val) +delete(obj, ::typeof(last)) = delete(obj, IndexLens((lastindex(obj),))) insert(obj, ::typeof(first), val) = insert(obj, IndexLens((firstindex(obj),)), val) +insert(obj, ::typeof(last), val) = insert(obj, IndexLens((lastindex(obj) + 1,)), val) set(obj, o::Base.Fix2{typeof(first)}, val) = @set obj[firstindex(obj):(firstindex(obj) + o.x - 1)] = val set(obj, o::Base.Fix2{typeof(last)}, val) = @set obj[(lastindex(obj) - o.x + 1):lastindex(obj)] = val @@ -15,12 +16,17 @@ delete(obj, o::Base.Fix2{typeof(last)}) = @delete obj[(lastindex(obj) - o.x + 1) insert(obj, o::Base.Fix2{typeof(first)}, val) = @insert obj[firstindex(obj):(firstindex(obj) + o.x - 1)] = val insert(obj, o::Base.Fix2{typeof(last)}, val) = @insert obj[(lastindex(obj) + 1):(lastindex(obj) + o.x)] = val +# first and last on ranges +# they don't support delete() with arbitrary index, so special casing is needed +delete(obj::AbstractRange, ::typeof(first)) = obj[begin+1:end] +delete(obj::AbstractRange, ::typeof(last)) = obj[begin:end-1] +delete(obj::AbstractRange, o::Base.Fix2{typeof(first)}) = obj[begin+o.x:end] +delete(obj::AbstractRange, o::Base.Fix2{typeof(last)}) = obj[begin:end-o.x] + + set(obj::Tuple, ::typeof(Base.front), val::Tuple) = (val..., last(obj)) set(obj::Tuple, ::typeof(Base.tail), val::Tuple) = (first(obj), val...) -set(obj, ::typeof(identity), val) = val -set(obj, ::typeof(inv), new_inv) = inv(new_inv) - function set(obj, ::typeof(only), val) only(obj) # error check set(obj, first, val) @@ -42,6 +48,8 @@ function set(obj::NamedTuple, ::Type{NamedTuple{KS}}, val::NamedTuple) where {KS setproperties(obj, NamedTuple{KS}(val)) end +set(obj, ::typeof(Base.splat(=>)), val::Pair) = @set Tuple(obj) = Tuple(val) + set(obj, ::typeof(getproperties), val::NamedTuple) = setproperties(obj, val) ################################################################################ @@ -125,10 +133,13 @@ set(x, f::Base.Fix2{typeof(rem)}, y) = set(x, @optic(last(divrem(_, f.x))), y) set(x::AbstractString, f::Base.Fix1{typeof(parse), Type{T}}, y::T) where {T} = string(y) set(arr, ::typeof(normalize), val) = norm(arr) * val -set(arr, ::typeof(norm), val) = val/norm(arr) * arr # should we check val is positive? +set(arr, ::typeof(norm), val) = map(Base.Fix2(*, val / norm(arr)), arr) # should we check val is positive? set(f, ::typeof(inverse), invf) = setinverse(f, invf) +set(obj, ::typeof(Base.splat(atan)), val) = @set Tuple(obj) = norm(obj) .* sincos(val) +set(obj, ::typeof(Base.splat(hypot)), val) = @set norm(obj) = val + ################################################################################ ##### dates ################################################################################ @@ -136,6 +147,11 @@ set(x::DateTime, ::Type{Date}, y) = DateTime(y, Time(x)) set(x::DateTime, ::Type{Time}, y) = DateTime(Date(x), y) set(x::T, ::Type{T}, y) where {T <: Union{Date, Time}} = y +# directly mirrors Dates.value implementation in stdlib +set(x::Date, ::typeof(Dates.value), y) = @set x.instant.periods.value = y +set(x::DateTime, ::typeof(Dates.value), y) = @set x.instant.periods.value = y +set(x::Time, ::typeof(Dates.value), y) = @set x.instant.value = y + set(x::Date, ::typeof(year), y) = Date(y, month(x), day(x)) set(x::Date, ::typeof(month), y) = Date(year(x), y, day(x)) set(x::Date, ::typeof(day), y) = Date(year(x), month(x), y) diff --git a/src/optics.jl b/src/optics.jl index f9f6213e..18ddec7f 100644 --- a/src/optics.jl +++ b/src/optics.jl @@ -436,6 +436,9 @@ end @inline insert(obj::NamedTuple, l::IndexLens{Tuple{Symbol}}, val) = merge(obj, NamedTuple{l.indices}((val,))) @inline insert(obj::NamedTuple, l::IndexLens{<:Tuple{Tuple{Vararg{Symbol}}}}, vals) = merge(obj, NamedTuple{only(l.indices)}(vals)) +@inline delete(obj::CartesianIndex, l::IndexLens{Tuple{Int}}) = delete(obj, l ∘ Tuple) +@inline insert(obj::CartesianIndex, l::IndexLens{Tuple{Int}}, val) = insert(obj, l ∘ Tuple, val) + struct DynamicIndexLens{F} f::F end diff --git a/test/Project.toml b/test/Project.toml index 7f6291ca..74bb2335 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -8,6 +8,7 @@ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" PerformanceTestTools = "dc46b164-d16f-48ec-a853-60448fc869fe" QuickTypes = "ae2dfa86-617c-530c-b392-ef20fdad97bb" diff --git a/test/test_delete.jl b/test/test_delete.jl index 6fd70612..edc3b414 100644 --- a/test/test_delete.jl +++ b/test/test_delete.jl @@ -20,6 +20,11 @@ using StaticArrays @test @inferred(delete( [1, 2, 3], @optic(first(_, 2)))) == [3] @test @inferred(delete( [1, 2, 3], @optic(last(_, 2)))) == [1] + @test @inferred(delete(CartesianIndex(1, 2, 3), @optic(_[1]))) == CartesianIndex(2, 3) + + @test @inferred(delete(1:4, last)) === 1:3 + @test @inferred(delete(1:4, (@optic first(_, 2)))) === 3:4 + l = @optic first(_, 2) @test l((1,2,3)) == [1,2] @test delete((1,2,3), l) === (3,) diff --git a/test/test_extensions.jl b/test/test_extensions.jl index 3708a4d6..fca7166e 100644 --- a/test/test_extensions.jl +++ b/test/test_extensions.jl @@ -83,6 +83,9 @@ end # requires ConstructionBase extension: VERSION >= v"1.9-" && @test (@set v.x = 10) === @SVector [10.,2,3] + v = @MVector [1.,2,3] + @test (@set v[1] = 10)::MVector == @MVector [10.,2,3] + @testset "Multi-dynamic indexing" begin two = 2 plusone(x) = x + 1 @@ -106,6 +109,19 @@ end v = @set StaticArrays.normalize(@SVector [10, 0,0]) = @SVector[0,1,0] @test v ≈ @SVector[0,10,0] @test @set(StaticArrays.norm([1,0]) = 20) ≈ [20, 0] + + cmp(a::NamedTuple, b::NamedTuple) = Set(keys(a)) == Set(keys(b)) && NamedTuple{keys(b)}(a) === b + cmp(a::T, b::T) where {T} = a == b + + if VERSION >= v"1.9-" + # require ConstructionBase extension + test_getset_laws(Tuple, SVector(0, 1), ('x', 'y'), (1, 2); cmp=cmp) + test_getset_laws(Tuple, MVector(0, 1), ('x', 'y'), (1, 2); cmp=cmp) + test_getset_laws(NamedTuple{(:x, :y)}, SVector(0, 1), (x='x', y='y'), (x=1, y=2); cmp=cmp) + test_getset_laws(NamedTuple{(:x, :y)}, SVector(0, 1), (y='x', x='y'), (x=1, y=2); cmp=cmp) + end + test_getset_laws(SVector, (0, 1), SVector('x', 'y'), SVector(1, 2); cmp=cmp) + test_getset_laws(MVector, (0, 1), MVector('x', 'y'), MVector(1, 2); cmp=cmp) end diff --git a/test/test_functionlenses.jl b/test/test_functionlenses.jl index ccfd8e29..abaa06fd 100644 --- a/test/test_functionlenses.jl +++ b/test/test_functionlenses.jl @@ -2,10 +2,10 @@ module TestFunctionLenses using Test using Dates using Unitful +using LinearAlgebra: norm using InverseFunctions: inverse using Accessors: test_getset_laws, test_modify_law using Accessors -using StaticArrays: SVector @testset "os" begin @@ -86,11 +86,14 @@ end cmp(a::NamedTuple, b::NamedTuple) = Set(keys(a)) == Set(keys(b)) && NamedTuple{keys(b)}(a) === b cmp(a::T, b::T) where {T} = a == b + + test_getset_laws(Base.splat(=>), (1, 'a'), 'b' => 2, 3 => 'c'; cmp=cmp) + test_getset_laws(Base.splat(Pair), (1, 'a'), 'b' => 2, 3 => 'c'; cmp=cmp) + test_getset_laws(Base.splat(=>), [1, 2], 3 => 2, 3 => 4; cmp=cmp) test_getset_laws(Tuple, (1, 'a'), ('x', 'y'), (1, 2)) test_getset_laws(Tuple, (a=1, b='a'), ('x', 'y'), (1, 2)) test_getset_laws(Tuple, [0, 1], ('x', 'y'), (1, 2); cmp=cmp) - test_getset_laws(Tuple, SVector(0, 1), ('x', 'y'), (1, 2); cmp=cmp) test_getset_laws(Tuple, CartesianIndex(1, 2), (3, 4), (5, 6)) test_getset_laws(NamedTuple{(:x, :y)}, (1, 'a'), (x='x', y='y'), (x=1, y=2); cmp=cmp) @@ -101,8 +104,6 @@ end test_getset_laws(NamedTuple{(:x, :y)}, (y=1, z=10, x='a'), (y='x', x='y'), (x=1, y=2); cmp=cmp) test_getset_laws(NamedTuple{(:x, :y)}, [0, 1], (x='x', y='y'), (x=1, y=2); cmp=cmp) test_getset_laws(NamedTuple{(:x, :y)}, [0, 1], (y='x', x='y'), (x=1, y=2); cmp=cmp) - test_getset_laws(NamedTuple{(:x, :y)}, SVector(0, 1), (x='x', y='y'), (x=1, y=2); cmp=cmp) - test_getset_laws(NamedTuple{(:x, :y)}, SVector(0, 1), (y='x', x='y'), (x=1, y=2); cmp=cmp) test_getset_laws(NamedTuple{(:x, :y)}, CartesianIndex(1, 2), (x=3, y=4), (x=5, y=6); cmp=cmp) test_getset_laws(NamedTuple{(:x, :y)}, CartesianIndex(1, 2), (y=3, x=4), (x=5, y=6); cmp=cmp) @@ -198,6 +199,9 @@ end test_getset_laws(mod2pi, 5.3, 1, 2; cmp=isapprox) test_getset_laws(mod2pi, -5.3, 1, 2; cmp=isapprox) + test_getset_laws(Base.splat(atan), (3, 4), 1, 2) + test_getset_laws(Base.splat(atan), (a=3, b=4), 1, 2) + test_getset_laws(!, true, true, false) @testset for o in [ # invertible lenses below: no need for extensive testing, simply forwarded to InverseFunctions @@ -241,6 +245,12 @@ end f = @set inverse(sin) = myasin @test f(2) == sin(2) @test inverse(f)(0.5) == asin(0.5) + 2π + + @test set([3, 4], norm, 10) == [6, 8] + @test set((3, 4), norm, 10) === (6., 8.) + @test set((a=3, b=4), norm, 10) === (a=6., b=8.) + test_getset_laws(norm, (3, 4), 10, 12) + test_getset_laws(Base.splat(hypot), (3, 4), 10, 12) end @testset "dates" begin @@ -264,6 +274,10 @@ end test_getset_laws(yearmonthday, x, (rand(1:5000), rand(1:12), rand(1:28)), (rand(1:5000), rand(1:12), rand(1:28))) end + @testset for x in [DateTime(2020, 1, 2, 3, 4, 5, 6), Date(2020, 1, 2), Time(1, 2, 3, 4, 5, 6)] + test_getset_laws(Dates.value, x, 123, 456) + end + l = @optic DateTime(_, dateformat"yyyy_mm_dd") @test @inferred(set("2020_03_04", month ∘ l, 10)) == "2020_10_04" test_getset_laws(month ∘ l, "2020_03_04", 10, 11) diff --git a/test/test_insert.jl b/test/test_insert.jl index 796b3923..451baba3 100644 --- a/test/test_insert.jl +++ b/test/test_insert.jl @@ -17,6 +17,7 @@ using Accessors: insert @test insert(A, @optic(last(_, 2)), [3, 4]) == [1, 2, 3, 4] @test A == [1, 2] # not changed end + @test @inferred(insert(CartesianIndex(1, 2, 3), @optic(_[2]), 4)) == CartesianIndex(1, 4, 2, 3) @test insert((1,2), last, 3) == (1, 2, 3) @inferred(insert((1,2), last, 3)) @test @inferred(insert(SVector(1,2), @optic(_[1]), 3)) == SVector(3, 1, 2)