diff --git a/docs/src/calculus.md b/docs/src/calculus.md index 7ba1afa..1b71488 100644 --- a/docs/src/calculus.md +++ b/docs/src/calculus.md @@ -12,8 +12,10 @@ DCAT ```@docs Compose -NonLinearCompose -Hadamard +HadamardProd +Ax_mul_Bx +Axt_mul_Bx +Ax_mul_Bxt ``` ## Transformations diff --git a/src/AbstractOperators.jl b/src/AbstractOperators.jl index db6f601..8580062 100644 --- a/src/AbstractOperators.jl +++ b/src/AbstractOperators.jl @@ -58,7 +58,11 @@ include("calculus/Sum.jl") include("calculus/AffineAdd.jl") include("calculus/Jacobian.jl") include("calculus/NonLinearCompose.jl") +include("calculus/Axt_mul_Bx.jl") +include("calculus/Ax_mul_Bxt.jl") +include("calculus/Ax_mul_Bx.jl") include("calculus/Hadamard.jl") +include("calculus/HadamardProd.jl") # Non-Linear operators include("nonlinearoperators/Pow.jl") diff --git a/src/calculus/Ax_mul_Bx.jl b/src/calculus/Ax_mul_Bx.jl new file mode 100644 index 0000000..46832f9 --- /dev/null +++ b/src/calculus/Ax_mul_Bx.jl @@ -0,0 +1,110 @@ +#Ax_mul_Bx + +export Ax_mul_Bx + +""" +`Ax_mul_Bx(A::AbstractOperator,B::AbstractOperator)` + +Create an operator `P` such that: + +`P*x == (Ax)*(Bx)` + +# Example + +```julia +julia> A,B = randn(4,4),randn(4,4); + +julia> P = Ax_mul_Bx(MatrixOp(A,4),MatrixOp(B,4)) +▒*▒ ℝ^4 -> ℝ^(4, 4) + +julia> X = randn(4,4); + +julia> P*X == (A*X)*(B*X) +true + +``` +""" +struct Ax_mul_Bx{ + L1 <: AbstractOperator, + L2 <: AbstractOperator, + C <: AbstractArray, + D <: AbstractArray, + } <: NonLinearOperator + A::L1 + B::L2 + bufA::C + bufB::C + bufC::C + bufD::D + function Ax_mul_Bx(A::L1, B::L2, bufA::C, bufB::C, bufC::C, bufD::D) where {L1,L2,C,D} + if ndims(A,1) != 2 || size(A,2) != size(B,2) || size(A,1)[2] != size(B,1)[1] + throw(DimensionMismatch("Cannot compose operators")) + end + new{L1,L2,C,D}(A,B,bufA,bufB,bufC,bufD) + end +end + +struct Ax_mul_BxJac{ + L1 <: AbstractOperator, + L2 <: AbstractOperator, + C <: AbstractArray, + D <: AbstractArray, + } <: LinearOperator + A::L1 + B::L2 + bufA::C + bufB::C + bufC::C + bufD::D +end + +# Constructors +function Ax_mul_Bx(A::AbstractOperator,B::AbstractOperator) + s,t = size(A,1), codomainType(A) + bufA = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...) + s,t = size(B,1), codomainType(B) + bufB = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...) + bufC = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...) + s,t = size(A,2), domainType(A) + bufD = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...) + Ax_mul_Bx(A,B,bufA,bufB,bufC,bufD) +end + +# Jacobian +function Jacobian(P::Ax_mul_Bx{L1,L2,C,D}, x::AbstractArray) where {L1,L2,C,D} + JA, JB = Jacobian(P.A, x), Jacobian(P.B, x) + Ax_mul_BxJac{typeof(JA),typeof(JB),C,D}(JA,JB,P.bufA,P.bufB,P.bufC,P.bufD) +end + +# Mappings +function mul!(y, P::Ax_mul_Bx{L1,L2,C,D}, b) where {L1,L2,C,D} + mul!(P.bufA,P.A,b) + mul!(P.bufB,P.B,b) + mul!(y,P.bufA, P.bufB) +end + +function mul!(y, J::AdjointOperator{Ax_mul_BxJac{L1,L2,C,D}}, b) where {L1,L2,C,D} + #y .= J.A.B' * ( J.A.bufA'*b ) + J.A.A' * ( b*J.A.bufB' ) + mul!(J.A.bufC, J.A.bufA', b) + mul!(y, J.A.B', J.A.bufC) + mul!(J.A.bufA, b, J.A.bufB') + mul!(J.A.bufD, J.A.A', J.A.bufA) + y .+= J.A.bufD + return y +end + +size(P::Union{Ax_mul_Bx,Ax_mul_BxJac}) = ((size(P.A,1)[1],size(P.B,1)[2]),size(P.A,2)) + +fun_name(L::Union{Ax_mul_Bx,Ax_mul_BxJac}) = fun_name(L.A)*"*"*fun_name(L.B) + +domainType(L::Union{Ax_mul_Bx,Ax_mul_BxJac}) = domainType(L.A) +codomainType(L::Union{Ax_mul_Bx,Ax_mul_BxJac}) = codomainType(L.A) + +# utils +function permute(P::Ax_mul_Bx{L1,L2,C,D}, + p::AbstractVector{Int}) where {L1,L2,C,D <:ArrayPartition} + Ax_mul_Bx(permute(P.A,p),permute(P.B,p),P.bufA,P.bufB,P.bufC,ArrayPartition(P.bufD.x[p]) ) +end + +remove_displacement(P::Ax_mul_Bx) = +Ax_mul_Bx(remove_displacement(P.A), remove_displacement(P.B), P.bufA, P.bufB, P.bufC, P.bufD) diff --git a/src/calculus/Ax_mul_Bxt.jl b/src/calculus/Ax_mul_Bxt.jl new file mode 100644 index 0000000..7e0f834 --- /dev/null +++ b/src/calculus/Ax_mul_Bxt.jl @@ -0,0 +1,118 @@ +#Ax_mul_Bxt + +export Ax_mul_Bxt + +""" +`Ax_mul_Bxt(A::AbstractOperator,B::AbstractOperator)` + +Create an operator `P` such that: + +`P == (Ax)*(Bx)'` + +# Example: Matrix multiplication + +```julia +julia> A,B = randn(4,4),randn(4,4); + +julia> P = Ax_mul_Bxt(MatrixOp(A),MatrixOp(B)) +▒*▒ ℝ^4 -> ℝ^(4, 4) + +julia> x = randn(4); + +julia> P*x == (A*x)*(B*x)' +true + +``` +""" +struct Ax_mul_Bxt{ + L1 <: AbstractOperator, + L2 <: AbstractOperator, + C <: AbstractArray, + D <: AbstractArray, + } <: NonLinearOperator + A::L1 + B::L2 + bufA::C + bufB::C + bufC::C + bufD::D + function Ax_mul_Bxt(A::L1, B::L2, bufA::C, bufB::C, bufC::C, bufD::D) where {L1,L2,C,D} + if ndims(A,1) == 1 + if size(A) != size(B) + throw(DimensionMismatch("Cannot compose operators")) + end + elseif ndims(A,1) == 2 && ndims(B,1) == 2 && size(A,2) == size(B,2) + if size(A,1)[2] != size(B,1)[2] + throw(DimensionMismatch("Cannot compose operators")) + end + else + throw(DimensionMismatch("Cannot compose operators")) + end + new{L1,L2,C,D}(A,B,bufA,bufB,bufC,bufD) + end +end + +struct Ax_mul_BxtJac{ + L1 <: AbstractOperator, + L2 <: AbstractOperator, + C <: AbstractArray, + D <: AbstractArray, + } <: LinearOperator + A::L1 + B::L2 + bufA::C + bufB::C + bufC::C + bufD::D +end + +# Constructors +function Ax_mul_Bxt(A::AbstractOperator,B::AbstractOperator) + s,t = size(A,1), codomainType(A) + bufA = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...) + bufC = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...) + s,t = size(B,1), codomainType(B) + bufB = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...) + s,t = size(A,2), domainType(A) + bufD = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...) + Ax_mul_Bxt(A,B,bufA,bufB,bufC,bufD) +end + +# Jacobian +function Jacobian(P::Ax_mul_Bxt{L1,L2,C,D}, x::AbstractArray) where {L1,L2,C,D} + JA, JB = Jacobian(P.A, x), Jacobian(P.B, x) + Ax_mul_BxtJac{typeof(JA),typeof(JB),C,D}(JA,JB,P.bufA,P.bufB,P.bufC,P.bufD) +end + +# Mappings +function mul!(y, P::Ax_mul_Bxt{L1,L2,C,D}, b) where {L1,L2,C,D} + mul!(P.bufA,P.A,b) + mul!(P.bufB,P.B,b) + mul!(y,P.bufA, P.bufB') +end + +function mul!(y, J::AdjointOperator{Ax_mul_BxtJac{L1,L2,C,D}}, b) where {L1,L2,C,D} + #y .= J.A.A'*(b*(J.A.bufB)) + J.A.B'*(b'*(J.A.bufA)) + mul!(J.A.bufC, b, J.A.bufB) + mul!(y, J.A.A', J.A.bufC) + mul!(J.A.bufB, b', J.A.bufA) + mul!(J.A.bufD, J.A.B', J.A.bufB) + y .+= J.A.bufD + return y +end + +size(P::Union{Ax_mul_Bxt,Ax_mul_BxtJac}) = ((size(P.A,1)[1],size(P.B,1)[1]),size(P.A,2)) + +fun_name(L::Union{Ax_mul_Bxt,Ax_mul_BxtJac}) = fun_name(L.A)*"*"*fun_name(L.B) + +domainType(L::Union{Ax_mul_Bxt,Ax_mul_BxtJac}) = domainType(L.A) +codomainType(L::Union{Ax_mul_Bxt,Ax_mul_BxtJac}) = codomainType(L.A) + +# utils +function permute(P::Ax_mul_Bxt{L1,L2,C,D}, + p::AbstractVector{Int}) where {L1,L2,C,D <:ArrayPartition} + Ax_mul_Bxt(permute(P.A,p),permute(P.B,p),P.bufA,P.bufB,P.bufC,ArrayPartition(P.bufD.x[p]) ) +end + +remove_displacement(P::Ax_mul_Bxt) = +Ax_mul_Bxt(remove_displacement(P.A), remove_displacement(P.B), P.bufA, P.bufB, P.bufC, P.bufD) diff --git a/src/calculus/Axt_mul_Bx.jl b/src/calculus/Axt_mul_Bx.jl new file mode 100644 index 0000000..2a951b3 --- /dev/null +++ b/src/calculus/Axt_mul_Bx.jl @@ -0,0 +1,137 @@ +#Axt_mul_Bx + +export Axt_mul_Bx + +""" +`Axt_mul_Bx(A::AbstractOperator,B::AbstractOperator)` + +Create an operator `P` such that: + +`P*x == (Ax)'*(Bx)` + +# Example + +```julia +julia> A,B = randn(4,4),randn(4,4); + +julia> P = Axt_mul_Bx(MatrixOp(A),MatrixOp(B)) +▒*▒ ℝ^4 -> ℝ^1 + +julia> x = randn(4); + +julia> P*x == [(A*x)'*(B*x)] +true + +``` +""" +struct Axt_mul_Bx{N, + L1 <: AbstractOperator, + L2 <: AbstractOperator, + C <: AbstractArray, + D <: AbstractArray, + } <: NonLinearOperator + A::L1 + B::L2 + bufA::C + bufB::C + bufC::C + bufD::D + function Axt_mul_Bx(A::L1, B::L2, bufA::C, bufB::C, bufC::C, bufD::D) where {L1,L2,C,D} + if ndims(A,1) == 1 + if size(A) != size(B) + throw(DimensionMismatch("Cannot compose operators")) + end + elseif ndims(A,1) == 2 && ndims(B,1) == 2 && size(A,2) == size(B,2) + if size(A,1)[1] != size(B,1)[1] + throw(DimensionMismatch("Cannot compose operators")) + end + else + throw(DimensionMismatch("Cannot compose operators")) + end + N = ndims(A,1) + new{N,L1,L2,C,D}(A,B,bufA,bufB,bufC,bufD) + end +end + +struct Axt_mul_BxJac{N, + L1 <: AbstractOperator, + L2 <: AbstractOperator, + C <: AbstractArray, + D <: AbstractArray, + } <: LinearOperator + A::L1 + B::L2 + bufA::C + bufB::C + bufC::C + bufD::D +end + +# Constructors +function Axt_mul_Bx(A::AbstractOperator,B::AbstractOperator) + s,t = size(A,1), codomainType(A) + bufA = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...) + bufC = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...) + s,t = size(B,1), codomainType(B) + bufB = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...) + s,t = size(A,2), domainType(A) + bufD = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...) + Axt_mul_Bx(A,B,bufA,bufB,bufC,bufD) +end + +# Jacobian +function Jacobian(P::Axt_mul_Bx{N,L1,L2,C,D}, x::AbstractArray) where {N,L1,L2,C,D} + JA, JB = Jacobian(P.A, x), Jacobian(P.B, x) + Axt_mul_BxJac{N,typeof(JA),typeof(JB),C,D}(JA,JB,P.bufA,P.bufB,P.bufC,P.bufD) +end + +# Mappings +# N == 1 input is a vector +function mul!(y, P::Axt_mul_Bx{1,L1,L2,C,D}, b) where {L1,L2,C,D} + mul!(P.bufA,P.A,b) + mul!(P.bufB,P.B,b) + y[1] = dot(P.bufA,P.bufB) +end + +function mul!(y, J::AdjointOperator{Axt_mul_BxJac{1,L1,L2,C,D}}, b) where {L1,L2,C,D} + #y .= conj(J.A.A'*J.A.bufB+J.A.B'*J.A.bufA).*b[1] + mul!(y, J.A.A', J.A.bufB) + mul!(J.A.bufD, J.A.B', J.A.bufA) + y .= conj.( y .+ J.A.bufD ) .* b[1] + return y +end + +# N == 2 input is a matrix +function mul!(y, P::Axt_mul_Bx{2,L1,L2,C,D}, b) where {L1,L2,C,D} + mul!(P.bufA,P.A,b) + mul!(P.bufB,P.B,b) + mul!(y,P.bufA',P.bufB) + return y +end + +function mul!(y, J::AdjointOperator{Axt_mul_BxJac{2,L1,L2,C,D}}, b) where {L1,L2,C,D} + # y .= J.A.A'*((J.A.bufB)*b') + J.A.B'*((J.A.bufA)*b) + mul!(J.A.bufC, J.A.bufB, b') + mul!(y, J.A.A', J.A.bufC) + mul!(J.A.bufB, J.A.bufA, b) + mul!(J.A.bufD, J.A.B', J.A.bufB) + y .+= J.A.bufD + return y +end + +size(P::Union{Axt_mul_Bx{1},Axt_mul_BxJac{1}}) = ((1,),size(P.A,2)) +size(P::Union{Axt_mul_Bx{2},Axt_mul_BxJac{2}}) = ((size(P.A,1)[2],size(P.B,1)[2]),size(P.A,2)) + +fun_name(L::Union{Axt_mul_Bx,Axt_mul_BxJac}) = fun_name(L.A)*"*"*fun_name(L.B) + +domainType(L::Union{Axt_mul_Bx,Axt_mul_BxJac}) = domainType(L.A) +codomainType(L::Union{Axt_mul_Bx,Axt_mul_BxJac}) = codomainType(L.A) + +# utils +function permute(P::Axt_mul_Bx{N,L1,L2,C,D}, + p::AbstractVector{Int}) where {N,L1,L2,C,D <:ArrayPartition} + Axt_mul_Bx(permute(P.A,p),permute(P.B,p),P.bufA,P.bufB,P.bufC,ArrayPartition(P.bufD.x[p]) ) +end + +remove_displacement(P::Axt_mul_Bx) = +Axt_mul_Bx(remove_displacement(P.A), remove_displacement(P.B), P.bufA, P.bufB, P.bufC, P.bufD) diff --git a/src/calculus/Hadamard.jl b/src/calculus/Hadamard.jl index 93e95bd..76b01d7 100644 --- a/src/calculus/Hadamard.jl +++ b/src/calculus/Hadamard.jl @@ -28,47 +28,47 @@ true ``` """ struct Hadamard{C, V <: VCAT} <: NonLinearOperator - A::V - buf::C - buf2::C - function Hadamard(A::V, buf::C, buf2::C) where {C, V <: VCAT} - any([ai != size(A,1)[1] for ai in size(A,1)]) && - throw(DimensionMismatch("cannot compose operators")) - - new{C, V}(A,buf,buf2) - end + A::V + buf::C + buf2::C + function Hadamard(A::V, buf::C, buf2::C) where {C, V <: VCAT} + any([ai != size(A,1)[1] for ai in size(A,1)]) && + throw(DimensionMismatch("cannot compose operators")) + @warn "`Hadamard` will be substituted by `HadamardProd` in future versions of AbstractOperators" + new{C, V}(A,buf,buf2) + end end struct HadamardJacobian{C, V <: VCAT} <: LinearOperator - A::V - buf::C - buf2::C - function HadamardJacobian(A::V,buf::C,buf2::C) where {C, V <: VCAT} - new{C, V}(A,buf,buf2) - end + A::V + buf::C + buf2::C + function HadamardJacobian(A::V,buf::C,buf2::C) where {C, V <: VCAT} + new{C, V}(A,buf,buf2) + end end # Constructors function Hadamard(L1::AbstractOperator,L2::AbstractOperator) - A = HCAT(L1, Zeros( domainType(L2), size(L2,2), codomainType(L1), size(L1,1) )) - B = HCAT(Zeros( domainType(L1), size(L1,2), codomainType(L2), size(L2,1) ), L2 ) + A = HCAT(L1, Zeros( domainType(L2), size(L2,2), codomainType(L1), size(L1,1) )) + B = HCAT(Zeros( domainType(L1), size(L1,2), codomainType(L2), size(L2,1) ), L2 ) V = VCAT(A,B) buf = ArrayPartition(zeros.(codomainType(V), size(V,1))) buf2 = ArrayPartition(zeros.(codomainType(V), size(V,1))) - Hadamard(V,buf,buf2) + Hadamard(V,buf,buf2) end # Mappings function mul!(y, H::Hadamard{C,V}, b::ArrayPartition) where {C,V} - mul!(H.buf,H.A,b) - y .= H.buf.x[1] + mul!(H.buf,H.A,b) + y .= H.buf.x[1] for i = 2:length(H.buf.x) y .*= H.buf.x[i] - end + end end # Jacobian @@ -79,8 +79,8 @@ function mul!(y::ArrayPartition, A::AdjointOperator{HadamardJacobian{C,V}}, b) w J = A.A for i = 1:length(J.buf.x) J.buf2.x[i] .= (.*)(J.buf.x[1:i-1]...,J.buf.x[i+1:end]...,b) - end - mul!(y, J.A', J.buf2) + end + mul!(y, J.A', J.buf2) end # Properties @@ -98,8 +98,8 @@ codomainType(L::HadamardJacobian) = codomainType(L.A[1]) # utils function permute(H::Hadamard, p::AbstractVector{Int}) - A = VCAT([permute(a,p) for a in H.A.A]...) - Hadamard(A,H.buf,H.buf2) + A = VCAT([permute(a,p) for a in H.A.A]...) + Hadamard(A,H.buf,H.buf2) end remove_displacement(N::Hadamard) = Hadamard(remove_displacement(N.A), N.buf, N.buf2) diff --git a/src/calculus/HadamardProd.jl b/src/calculus/HadamardProd.jl new file mode 100644 index 0000000..b3e8ad1 --- /dev/null +++ b/src/calculus/HadamardProd.jl @@ -0,0 +1,109 @@ +#HadamardProd + +export HadamardProd + +""" +`HadamardProd(A::AbstractOperator,B::AbstractOperator)` + +Create an operator `P` such that: + +`P*x == (Ax).*(Bx)` + +# Example + +```julia +julia> A,B = Sin(3), Cos(3); + +julia> P = HadamardProd(A,B) +sin.*cos ℝ^3 -> ℝ^3 + +julia> x = randn(3); + +julia> P*x == (sin.(x).*cos.(x)) +true + + +``` +""" +struct HadamardProd{ + L1 <: AbstractOperator, + L2 <: AbstractOperator, + C <: AbstractArray, + D <: AbstractArray, + } <: NonLinearOperator + A::L1 + B::L2 + bufA::C + bufB::C + bufD::D + function HadamardProd(A::L1, B::L2, bufA::C, bufB::C, bufD::D) where {L1,L2,C,D} + if size(A) != size(B) + throw(DimensionMismatch("Cannot compose operators")) + end + new{L1,L2,C,D}(A,B,bufA,bufB,bufD) + end +end + +struct HadamardProdJac{ + L1 <: AbstractOperator, + L2 <: AbstractOperator, + C <: AbstractArray, + D <: AbstractArray, + } <: LinearOperator + A::L1 + B::L2 + bufA::C + bufB::C + bufD::D +end + +# Constructors +function HadamardProd(A::AbstractOperator,B::AbstractOperator) + s,t = size(A,1), codomainType(A) + bufA = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...) + s,t = size(B,1), codomainType(B) + bufB = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...) + s,t = size(A,2), domainType(A) + bufD = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...) + HadamardProd(A,B,bufA,bufB,bufD) +end + +# Jacobian +function Jacobian(P::HadamardProd{L1,L2,C,D}, x::AbstractArray) where {L1,L2,C,D} + JA, JB = Jacobian(P.A, x), Jacobian(P.B, x) + HadamardProdJac{typeof(JA),typeof(JB),C,D}(JA,JB,P.bufA,P.bufB,P.bufD) +end + +# Mappings +function mul!(y, P::HadamardProd{L1,L2,C,D}, b) where {L1,L2,C,D} + mul!(P.bufA,P.A,b) + mul!(P.bufB,P.B,b) + y .= P.bufA .* P.bufB + return y +end + +function mul!(y, J::AdjointOperator{HadamardProdJac{L1,L2,C,D}}, b) where {L1,L2,C,D} + #y .= J.A.B' * ( J.A.bufA .*b ) + J.A.A' * ( J.A.bufB .* b ) + J.A.bufA .*= b + mul!(y, J.A.B', J.A.bufA) + J.A.bufB .*= b + mul!(J.A.bufD, J.A.A', J.A.bufB) + y .+= J.A.bufD + return y +end + +size(P::Union{HadamardProd,HadamardProdJac}) = (size(P.A,1),size(P.A,2)) + +fun_name(L::Union{HadamardProd,HadamardProdJac}) = fun_name(L.A)*".*"*fun_name(L.B) + +domainType(L::Union{HadamardProd,HadamardProdJac}) = domainType(L.A) +codomainType(L::Union{HadamardProd,HadamardProdJac}) = codomainType(L.A) + +# utils +function permute(P::HadamardProd{L1,L2,C,D}, + p::AbstractVector{Int}) where {L1,L2,C,D <:ArrayPartition} + HadamardProd(permute(P.A,p),permute(P.B,p),P.bufA,P.bufB,ArrayPartition(P.bufD.x[p]) ) +end + +remove_displacement(P::HadamardProd) = +HadamardProd(remove_displacement(P.A), remove_displacement(P.B), P.bufA, P.bufB, P.bufD) diff --git a/src/calculus/NonLinearCompose.jl b/src/calculus/NonLinearCompose.jl index 7a47df0..01b1d0d 100644 --- a/src/calculus/NonLinearCompose.jl +++ b/src/calculus/NonLinearCompose.jl @@ -5,9 +5,11 @@ export NonLinearCompose """ `NonLinearCompose(A::AbstractOperator,B::AbstractOperator)` -Compose opeators in such fashion: +Compose opeators such that: -`A(⋅)*B(⋅)` +`A(x1)*B(x2)` + +where `x1` and `x2` are two independent inputs. # Example: Matrix multiplication @@ -27,70 +29,71 @@ true ``` """ struct NonLinearCompose{ - L1 <: HCAT, - L2 <: HCAT, - C <: AbstractArray, - D <: AbstractArray - } <: NonLinearOperator - A::L1 - B::L2 - buf::C - bufx::D - function NonLinearCompose(A::L1, B::L2, buf::C, bufx::D) where {L1,L2,C,D} - if ( (ndoms(A,1) > 1 || ndoms(B,1) > 1) || - (ndims(A,1) > 2 || ndims(B,1) > 2) || - (size(B,1)[1] == 1 ? (length(size(A,1)) == 1 ? false : true) : # outer product case - ndims(A,1) == 1 ? true : (size(A,1)[2] != size(B,1)[1])) - ) - throw(DimensionMismatch("cannot compose operators")) - end - new{L1,L2,C,D}(A,B,buf,bufx) - end + L1 <: HCAT, + L2 <: HCAT, + C <: AbstractArray, + D <: AbstractArray + } <: NonLinearOperator + A::L1 + B::L2 + buf::C + bufx::D + function NonLinearCompose(A::L1, B::L2, buf::C, bufx::D) where {L1,L2,C,D} + if ( (ndoms(A,1) > 1 || ndoms(B,1) > 1) || + (ndims(A,1) > 2 || ndims(B,1) > 2) || + (size(B,1)[1] == 1 ? (length(size(A,1)) == 1 ? false : true) : # outer product case + ndims(A,1) == 1 ? true : (size(A,1)[2] != size(B,1)[1])) + ) + throw(DimensionMismatch("cannot compose operators")) + end + @warn "`NonLinearCompose` will be substituted by `Ax_mul_Bx` in future versions of AbstractOperators" + new{L1,L2,C,D}(A,B,buf,bufx) + end end struct NonLinearComposeJac{ - L1 <: HCAT, - L2 <: HCAT, - C <: AbstractArray, - D <: AbstractArray - } <: LinearOperator - A::L1 - B::L2 - buf::C - bufx::D + L1 <: HCAT, + L2 <: HCAT, + C <: AbstractArray, + D <: AbstractArray + } <: LinearOperator + A::L1 + B::L2 + buf::C + bufx::D end # Constructors function NonLinearCompose(L1::AbstractOperator,L2::AbstractOperator) - A = HCAT(L1, Zeros( domainType(L2), size(L2,2), codomainType(L1), size(L1,1) )) - B = HCAT(Zeros( domainType(L1), size(L1,2), codomainType(L2), size(L2,1) ), L2 ) + A = HCAT(L1, Zeros( domainType(L2), size(L2,2), codomainType(L1), size(L1,1) )) + B = HCAT(Zeros( domainType(L1), size(L1,2), codomainType(L2), size(L2,1) ), L2 ) buf = ArrayPartition(zeros(codomainType(A),size(A,1)), zeros(codomainType(B),size(B,1))) bufx = ArrayPartition(zeros(codomainType(L1),size(L1,1)), zeros(codomainType(L2),size(L2,1))) - NonLinearCompose(A,B,buf,bufx) + NonLinearCompose(A,B,buf,bufx) end # Jacobian function Jacobian(P::NonLinearCompose{L1,L2,C,D}, x::AbstractArray) where {L1,L2,C,D} - NonLinearComposeJac(Jacobian(P.A,x),Jacobian(P.B,x),P.buf,P.bufx) + NonLinearComposeJac(Jacobian(P.A,x),Jacobian(P.B,x),P.buf,P.bufx) end # Mappings function mul!(y, P::NonLinearCompose{L1,L2,C,D}, b) where {L1,L2,C,D} - mul_skipZeros!(P.buf.x[1],P.A,b) - mul_skipZeros!(P.buf.x[2],P.B,b) - mul!(y,P.buf.x[1],P.buf.x[2]) + mul_skipZeros!(P.buf.x[1],P.A,b) + mul_skipZeros!(P.buf.x[2],P.B,b) + mul!(y,P.buf.x[1],P.buf.x[2]) end function mul!(y, J::AdjointOperator{NonLinearComposeJac{L1,L2,C,D}}, b) where {L1,L2,C,D} P = J.A - mul!(P.bufx.x[1],b,P.buf.x[2]') - mul_skipZeros!(y,P.A',P.bufx.x[1]) + mul!(P.bufx.x[1],b,P.buf.x[2]') + mul_skipZeros!(y,P.A',P.bufx.x[1]) - mul!(P.bufx.x[2],P.buf.x[1]',b) - mul_skipZeros!(y,P.B',P.bufx.x[2]) + mul!(P.bufx.x[2],P.buf.x[1]',b) + mul_skipZeros!(y,P.B',P.bufx.x[2]) end # special case outer product @@ -105,21 +108,21 @@ function mul!(y, mul!(p,b,P.buf.x[2]') mul_skipZeros!(y,P.A',P.bufx.x[1]) - mul!(P.bufx.x[2],P.buf.x[1]',b) - mul_skipZeros!(y,P.B',P.bufx.x[2]) + mul!(P.bufx.x[2],P.buf.x[1]',b) + mul_skipZeros!(y,P.B',P.bufx.x[2]) end # Properties function size(P::NonLinearCompose) - size_out = ndims(P.B,1) == 1 ? (size(P.A,1)[1],) : - (size(P.A,1)[1], size(P.B,1)[2]) - size_out, size(P.A,2) + size_out = ndims(P.B,1) == 1 ? (size(P.A,1)[1],) : + (size(P.A,1)[1], size(P.B,1)[2]) + size_out, size(P.A,2) end function size(P::NonLinearComposeJac) - size_out = ndims(P.B,1) == 1 ? (size(P.A,1)[1],) : - (size(P.A,1)[1], size(P.B,1)[2]) - size_out, size(P.A,2) + size_out = ndims(P.B,1) == 1 ? (size(P.A,1)[1],) : + (size(P.A,1)[1], size(P.B,1)[2]) + size_out, size(P.A,2) end fun_name(L::NonLinearCompose) = fun_name(L.A.A[1])*"*"*fun_name(L.B.A[2]) @@ -133,7 +136,7 @@ codomainType(L::NonLinearComposeJac) = codomainType(L.A) # utils function permute(P::NonLinearCompose{L,C,D}, p::AbstractVector{Int}) where {L,C,D} - NonLinearCompose(permute(P.A,p),permute(P.B,p),P.buf,P.bufx) + NonLinearCompose(permute(P.A,p),permute(P.B,p),P.buf,P.bufx) end remove_displacement(N::NonLinearCompose) = NonLinearCompose(remove_displacement(N.A), remove_displacement(N.B), N.buf, N.bufx) diff --git a/src/calculus/Sum.jl b/src/calculus/Sum.jl index 91d4cd0..d8db14d 100644 --- a/src/calculus/Sum.jl +++ b/src/calculus/Sum.jl @@ -4,15 +4,15 @@ struct Sum{K, C <: AbstractArray, D <: AbstractArray, L <:NTuple{K,AbstractOperator}} <: AbstractOperator - A::L - bufC::C - bufD::D + A::L + bufC::C + bufD::D function Sum(A::L, bufC::C, bufD::D) where {C, D, K, L <: NTuple{K,AbstractOperator}} if any([size(a) != size(A[1]) for a in A]) throw(DimensionMismatch("cannot sum operator of different sizes")) end if any([codomainType(A[1]) != codomainType(a) for a in A]) || - any([ domainType(A[1]) != domainType(a) for a in A]) + any([ domainType(A[1]) != domainType(a) for a in A]) throw(DomainError(A,"cannot sum operator with different codomains")) end new{K, C, D, L}(A, bufC, bufD) @@ -22,15 +22,15 @@ end Sum(A::AbstractOperator) = A function Sum(A::Vararg{AbstractOperator}) - s = size(A[1],1) - t = codomainType(A[1]) + s = size(A[1],1) + t = codomainType(A[1]) bufC = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...) - s = size(A[1],2) - t = domainType(A[1]) + s = size(A[1],2) + t = domainType(A[1]) bufD = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...) - return Sum(A, bufC, bufD) + return Sum(A, bufC, bufD) end # special cases @@ -40,41 +40,41 @@ Sum((L1,L2.A...), L2.bufC, L2.bufD) # Mappings @generated function mul!(y::C, S::Sum{K,C,D}, b::D) where {K,C,D} - ex = :(mul!(y,S.A[1],b)) - for i = 2:K - ex = quote - $ex - mul!(S.bufC,S.A[$i],b) - end + ex = :(mul!(y,S.A[1],b)) + for i = 2:K + ex = quote + $ex + mul!(S.bufC,S.A[$i],b) + end ex = :($ex; y .+= S.bufC) - end - ex = quote - $ex - return y - end + end + ex = quote + $ex + return y + end end @generated function mul!(y::D, A::AdjointOperator{Sum{K,C,D,L}}, b::C) where {K,C,D,L} - ex = :(S = A.A; mul!(y,S.A[1]',b)) - for i = 2:K - ex = quote - $ex - mul!(S.bufD, S.A[$i]', b) - end - ex = :($ex; y .+= S.bufD) - end - ex = quote - $ex - return y - end + ex = :(S = A.A; mul!(y,S.A[1]',b)) + for i = 2:K + ex = quote + $ex + mul!(S.bufD, S.A[$i]', b) + end + ex = :($ex; y .+= S.bufD) + end + ex = quote + $ex + return y + end end # Properties size(L::Sum) = size(L.A[1]) - domainType(S::Sum{K, C, D, L}) where {K,C,D<:AbstractArray,L} = domainType(S.A[1]) - domainType(S::Sum{K, C, D, L}) where {K,C,D<:Tuple ,L} = domainType.(Ref(S.A[1])) +domainType(S::Sum{K, C, D, L}) where {K,C,D<:AbstractArray,L} = domainType(S.A[1]) +domainType(S::Sum{K, C, D, L}) where {K,C,D<:Tuple ,L} = domainType.(Ref(S.A[1])) codomainType(S::Sum{K, C, D, L}) where {K,C<:AbstractArray,D,L} = codomainType(S.A[1]) codomainType(S::Sum{K, C, D, L}) where {K,C<:Tuple ,D,L} = codomainType.(Ref(S.A[1])) @@ -95,8 +95,8 @@ diag(L::Sum) = (+).(diag.(L.A)...,) # utils function permute(S::Sum, p::AbstractVector{Int}) - AA = ([permute(A,p) for A in S.A]...,) - return Sum(AA, S.bufC, ArrayPartition(S.bufD.x[p]...)) + AA = ([permute(A,p) for A in S.A]...,) + return Sum(AA, S.bufC, ArrayPartition(S.bufD.x[p]...)) end remove_displacement(S::Sum) = Sum(remove_displacement.(S.A), S.bufC, S.bufD) diff --git a/src/linearoperators/LBFGS.jl b/src/linearoperators/LBFGS.jl index b155523..4b6064a 100644 --- a/src/linearoperators/LBFGS.jl +++ b/src/linearoperators/LBFGS.jl @@ -23,28 +23,28 @@ Use `reset!(L)` to cancel the memory of the operator. """ mutable struct LBFGS{R, T <: AbstractArray, M, I <: Integer} <: LinearOperator - currmem::I - curridx::I - s::T - y::T - s_M::Array{T, 1} - y_M::Array{T, 1} - ys_M::Array{R, 1} - alphas::Array{R, 1} - H::R + currmem::I + curridx::I + s::T + y::T + s_M::Array{T, 1} + y_M::Array{T, 1} + ys_M::Array{R, 1} + alphas::Array{R, 1} + H::R end #default constructor function LBFGS(x::T, M::I) where {T <: AbstractArray, I <: Integer} - s_M = [zero(x) for i = 1:M] - y_M = [zero(x) for i = 1:M] - s = zero(x) - y = zero(x) + s_M = [zero(x) for i = 1:M] + y_M = [zero(x) for i = 1:M] + s = zero(x) + y = zero(x) R = real(eltype(x)) - ys_M = zeros(R, M) - alphas = zeros(R, M) - LBFGS{R, T, M, I}(0, 0, s, y, s_M, y_M, ys_M, alphas, one(R)) + ys_M = zeros(R, M) + alphas = zeros(R, M) + LBFGS{R, T, M, I}(0, 0, s, y, s_M, y_M, ys_M, alphas, one(R)) end """ @@ -53,20 +53,20 @@ end See the documentation for `LBFGS`. """ function update!(L::LBFGS{R, T, M, I}, x::T, x_prev::T, gradx::T, gradx_prev::T) where {R, T, M, I} - L.s .= x .- x_prev - L.y .= gradx .- gradx_prev - ys = real(dot(L.s, L.y)) - if ys > 0 - L.curridx += 1 - if L.curridx > M L.curridx = 1 end - L.currmem += 1 - if L.currmem > M L.currmem = M end - L.ys_M[L.curridx] = ys - L.s_M[L.curridx] .= L.s - L.y_M[L.curridx] .= L.y - yty = real(dot(L.y, L.y)) - L.H = ys/yty - end + L.s .= x .- x_prev + L.y .= gradx .- gradx_prev + ys = real(dot(L.s, L.y)) + if ys > 0 + L.curridx += 1 + if L.curridx > M L.curridx = 1 end + L.currmem += 1 + if L.currmem > M L.currmem = M end + L.ys_M[L.curridx] = ys + L.s_M[L.curridx] .= L.s + L.y_M[L.curridx] .= L.y + yty = real(dot(L.y, L.y)) + L.H = ys/yty + end return L end @@ -76,8 +76,8 @@ end Cancels the memory of `L`. """ function reset!(L::LBFGS) - L.currmem, L.curridx = 0, 0 - L.H = 1.0 + L.currmem, L.curridx = 0, 0 + L.H = 1.0 end # LBFGS operators are symmetric @@ -87,32 +87,32 @@ mul!(x::T, L::AdjointOperator{LBFGS{R, T, M, I}}, y::T) where {R, T, M, I} = mul # Two-loop recursion function mul!(d::T, L::LBFGS{R, T, M, I}, gradx::T) where {R, T, M, I} - d .= gradx - idx = loop1!(d,L) - d .*= L.H - loop2!(d,idx,L) + d .= gradx + idx = loop1!(d,L) + d .*= L.H + loop2!(d,idx,L) return d end function loop1!(d::T, L::LBFGS{R, T, M, I}) where {R, T, M, I} - idx = L.curridx - for i = 1:L.currmem - L.alphas[idx] = real(dot(L.s_M[idx], d))/L.ys_M[idx] - d .-= L.alphas[idx] .* L.y_M[idx] - idx -= 1 - if idx == 0 idx = M end - end - return idx + idx = L.curridx + for i = 1:L.currmem + L.alphas[idx] = real(dot(L.s_M[idx], d))/L.ys_M[idx] + d .-= L.alphas[idx] .* L.y_M[idx] + idx -= 1 + if idx == 0 idx = M end + end + return idx end function loop2!(d::T, idx, L::LBFGS{R, T, M, I}) where {R, T, M, I} - for i = 1:L.currmem - idx += 1 - if idx > M idx = 1 end - beta = real(dot(L.y_M[idx], d))/L.ys_M[idx] - d .+= (L.alphas[idx] - beta) .* L.s_M[idx] - end - return d + for i = 1:L.currmem + idx += 1 + if idx > M idx = 1 end + beta = real(dot(L.y_M[idx], d))/L.ys_M[idx] + d .+= (L.alphas[idx] - beta) .* L.s_M[idx] + end + return d end # Properties diff --git a/test/test_nonlinear_operators.jl b/test/test_nonlinear_operators.jl index ea64163..bee3fc1 100644 --- a/test/test_nonlinear_operators.jl +++ b/test/test_nonlinear_operators.jl @@ -9,9 +9,9 @@ op = Sigmoid(Float64,(n,),2) y, grad = test_NLop(op,x,r,verb) n,m,l = 4,5,6 -x = randn(n,m,l) -r = randn(n,m,l) -op = Sigmoid((n,m,l),2) +x = randn(n,m) +r = randn(n,m) +op = Sigmoid((n,m),2) y, grad = test_NLop(op,x,r,verb) @@ -61,7 +61,7 @@ op = Sin(Float64,(n,m,l)) y, grad = test_NLop(op,x,r,verb) - Cos +## Cos n,m,l = 4,5,6 x = randn(n,m,l) r = randn(n,m,l) diff --git a/test/test_nonlinear_operators_calculus.jl b/test/test_nonlinear_operators_calculus.jl index fc150ab..c18d1b4 100644 --- a/test/test_nonlinear_operators_calculus.jl +++ b/test/test_nonlinear_operators_calculus.jl @@ -565,4 +565,269 @@ y, grad = test_NLop(opP,xp,r,verb) Y = (A*x.x[1]+d1)*(B*x.x[2]+d2) @test norm(Y[:] - y) <= 1e-12 +#### Axt_mul_Bx +n = 10 +A,B = Eye(n),Sin(n) +P = Axt_mul_Bx(A,B) + +x = randn(n) +r = randn(1) +y, grad = test_NLop(P,x,r,verb) +@test norm([(A*x)'*(B*x)]-y) < 1e-8 + +n,m = 3,4 +A,B = MatrixOp(randn(n,m)), MatrixOp(randn(n,m)) +P = Axt_mul_Bx(A,B) + +x = randn(m) +r = randn(1) +y, grad = test_NLop(P,x,r,verb) +@test norm([(A*x)'*(B*x)]-y) < 1e-8 + +n,m,l = 3,7,5 +A,B = MatrixOp(randn(n,m),l), MatrixOp(randn(n,m),l) +P = Axt_mul_Bx(A,B) +x = randn(m,l) +r = randn(l,l) +y, grad = test_NLop(P,x,r,verb) +@test norm((A*x)'*(B*x)-y) < 1e-8 + +n,m = 3,7 +A,B = Sin(n,m), Cos(n,m) +P = Axt_mul_Bx(A,B) +x = randn(n,m) +r = randn(m,m) +y, grad = test_NLop(P,x,r,verb) +@test norm((A*x)'*(B*x)-y) < 1e-8 + +# testing with HCAT +m,n = 3,5 +x = ArrayPartition(randn(m),randn(n)) +r = randn(1) +b = randn(m) +A = AffineAdd(Sin(Float64,(m,)),b) +B = MatrixOp(randn(m,n)) +op1 = HCAT(A,B) +C = Cos(Float64,(m,)) +D = MatrixOp(randn(m,n)) +op2 = HCAT(C,D) +P = Axt_mul_Bx(op1,op2) +y, grad = test_NLop(P,x,r,verb) +@test norm([(op1*x)'*(op2*x)]-y) < 1e-8 + +#test remove_displacement +y2, grad = test_NLop(remove_displacement(P),x,r,verb) +@test norm([(op1*x-b)'*(op2*x)]-y2) < 1e-8 + +# test permute +p = [2,1] +Pp = AbstractOperators.permute(P,p) +xp = ArrayPartition(x.x[p]) +y2, grad = test_NLop(Pp,xp,r,verb) +@test norm(y2-y) < 1e-8 + +@test_throws Exception Axt_mul_Bx(Eye(2,2), Eye(2,1)) +@test_throws Exception Axt_mul_Bx(Eye(2,2,2), Eye(2,2,2)) + +## Ax_mul_Bxt +n = 10 +A,B = Eye(n),Sin(n) +P = Ax_mul_Bxt(A,B) +x = randn(n) +r = randn(n,n) +y, grad = test_NLop(P,x,r,verb) +@test norm((A*x)*(B*x)'-y) < 1e-9 + +n,m = 3,4 +A,B = MatrixOp(randn(n,m)), MatrixOp(randn(n,m)) +P = Ax_mul_Bxt(A,B) +x = randn(m) +r = randn(n,n) +y, grad = test_NLop(P,x,r,verb) +@test norm((A*x)*(B*x)'-y) < 1e-8 + +n,m,l = 3,7,5 +A,B = MatrixOp(randn(n,m),l), MatrixOp(randn(n,m),l) +P = Ax_mul_Bxt(A,B) +x = randn(m,l) +r = randn(n,n) +y, grad = test_NLop(P,x,r,verb) +@test norm((A*x)*(B*x)'-y) < 1e-8 + +n,m = 3,7 +A,B = Sin(n,m), Cos(n,m) +P = Ax_mul_Bxt(A,B) +x = randn(n,m) +r = randn(n,n) +y, grad = test_NLop(P,x,r,verb) +@test norm((A*x)*(B*x)'-y) < 1e-8 + +# testing with HCAT +m,n = 3,5 +x = ArrayPartition(randn(m),randn(n)) +r = randn(m,m) +b = randn(m) +A = AffineAdd(Sin(Float64,(m,)),b) +B = MatrixOp(randn(m,n)) +op1 = HCAT(A,B) +C = Cos(Float64,(m,)) +D = MatrixOp(randn(m,n)) +op2 = HCAT(C,D) +P = Ax_mul_Bxt(op1,op2) +y, grad = test_NLop(P,x,r,verb) +@test norm((op1*x)*(op2*x)'-y) < 1e-8 + +#test remove_displacement +y2, grad = test_NLop(remove_displacement(P),x,r,verb) +@test norm((op1*x-b)*(op2*x)'-y2) < 1e-8 + +# test permute +p = [2,1] +Pp = AbstractOperators.permute(P,p) +xp = ArrayPartition(x.x[p]) +y2, grad = test_NLop(Pp,xp,r,verb) +@test norm(y2-y) < 1e-8 + +@test_throws Exception Ax_mul_Bxt(Eye(2,2), Eye(2,1)) +@test_throws Exception Ax_mul_Bxt(Eye(2,2,2), Eye(2,2,2)) + +## Ax_mul_Bx + +n = 3 +A,B = Eye(n,n), Eye(n,n) +P = Ax_mul_Bx(A,B) +x = randn(n,n) +r = randn(n,n) +y, grad = test_NLop(P,x,r,verb) +@test norm(x*x-y) < 1e-9 + +n = 3 +A,B = Sin(n,n), Cos(n,n) +P = Ax_mul_Bx(A,B) +x = randn(n,n) +r = randn(n,n) +y, grad = test_NLop(P,x,r,verb) +@test norm((A*x)*(B*x)-y) < 1e-9 + +n = 3 +A,B,C = Sin(n,n), Cos(n,n), Atan(n,n) +P = Ax_mul_Bx(A,B) +P2 = Ax_mul_Bx(C,P) +x = randn(n,n) +r = randn(n,n) +y, grad = test_NLop(P2,x,r,verb) +@test norm((C*x)*(A*x)*(B*x)-y) < 1e-9 + +n,l = 2,3 +A,B = MatrixOp(randn(l,n),l), MatrixOp(randn(l,n),l) +P = Ax_mul_Bx(A,B) +x = randn(n,l) +r = randn(l,l) +y, grad = test_NLop(P,x,r,verb) +@test norm((A*x)*(B*x)-y) < 1e-8 + +@test_throws Exception Ax_mul_Bx(Eye(2), Eye(2)) +@test_throws Exception Ax_mul_Bx(Eye(2,2), Eye(2,1)) +@test_throws Exception Ax_mul_Bx(Eye(2,2,2), Eye(2,2,2)) + +# testing with HCAT +m,n = 3,5 +x = ArrayPartition(randn(n,n),randn(m,n)) +r = randn(n,n) +b = randn(n,n) +A = AffineAdd(Sin(Float64,(n,n)),b) +B = MatrixOp(randn(n,m),n) +op1 = HCAT(A,B) +C = Sin(Float64,(n,n)) +D = MatrixOp(randn(n,m),n) +op2 = HCAT(C,D) +P = Ax_mul_Bx(op1,op2) +y, grad = test_NLop(P,x,r,verb) +@test norm((op1*x)*(op2*x)-y) < 1e-8 + +#test remove_displacement +y2, grad = test_NLop(remove_displacement(P),x,r,verb) +@test norm((op1*x-b)*(op2*x)-y2) < 1e-8 + +# test permute +p = [2,1] +Pp = AbstractOperators.permute(P,p) +xp = ArrayPartition(x.x[p]) +y2, grad = test_NLop(Pp,xp,r,verb) +@test norm(y2-y) < 1e-8 + +#### some combos of Ax_mul_Bx etc... +n,m,l = 3,7,5 +A,B = MatrixOp(randn(n,m),l), MatrixOp(randn(n,m),l) +P = Ax_mul_Bxt(A,B) +P2 = Axt_mul_Bx(A,P) +x = randn(m,l) +r = randn(l,n) +y, grad = test_NLop(P2,x,r,verb) +@test norm((A*x)'*((A*x)*(B*x)')-y) < 1e-8 + +n,m,l,k = 3,7,5,9 +A,B = MatrixOp(randn(n,m),l), MatrixOp(randn(n,m),l) +C = MatrixOp(randn(k,m),l) +P = Axt_mul_Bx(A,B) +P2 = Ax_mul_Bx(C,P) +x = randn(m,l) +r = randn(k,l) +y, grad = test_NLop(P2,x,r,verb) +@test norm((C*x)*((A*x)'*(B*x))-y) < 1e-8 + +n,m,l,k = 3,7,5,9 +A,B = MatrixOp(randn(n,m),l), MatrixOp(randn(n,m),l) +C = MatrixOp(randn(k,m),l) +P = Axt_mul_Bx(A,B) +P2 = Ax_mul_Bxt(C,P) +x = randn(m,l) +r = randn(k,l) +y, grad = test_NLop(P2,x,r,verb) +@test norm((C*x)*((A*x)'*(B*x))'-y) < 1e-8 + +#### HadamardProd + +n = 3 +A,B = Eye(n,n), Eye(n,n) +P = HadamardProd(A,B) +x = randn(n,n) +r = randn(n,n) +y, grad = test_NLop(P,x,r,verb) +@test norm(x.*x-y) < 1e-9 + +n,l = 3,2 +A,B = Sin(n,l), Cos(n,l) +P = HadamardProd(A,B) +x = randn(n,l) +r = randn(n,l) +y, grad = test_NLop(P,x,r,verb) +@test norm((A*x).*(B*x)-y) < 1e-9 + +# testing with HCAT +m,n = 3,5 +x = ArrayPartition(randn(m),randn(n)) +r = randn(m) +b = randn(m) +A = AffineAdd(Sin(Float64,(m,)),b) +B = MatrixOp(randn(m,n)) +op1 = HCAT(A,B) +C = Cos(Float64,(m,)) +D = MatrixOp(randn(m,n)) +op2 = HCAT(C,D) +P = HadamardProd(op1,op2) +y, grad = test_NLop(P,x,r,verb) +@test norm((op1*x).*(op2*x)-y) < 1e-9 + +#test remove_displacement +y2, grad = test_NLop(remove_displacement(P),x,r,verb) +@test norm((op1*x-b).*(op2*x)-y2) < 1e-8 + +# test permute +p = [2,1] +Pp = AbstractOperators.permute(P,p) +xp = ArrayPartition(x.x[p]) +y2, grad = test_NLop(Pp,xp,r,verb) +@test norm(y2-y) < 1e-8 +@test_throws Exception HadamardProd(Eye(2,2,2), Eye(1,2,2))