Skip to content

Commit

Permalink
New calculus rules (#10)
Browse files Browse the repository at this point in the history
* `Ax_mul_Bx` --> Generalizes `NonLinearCompose`
* `Axt_mul_Bx`
* `Ax_mul_Bxt`
* `HadamardProd` --> Generalizes `Hadamard`

`Hadamard` & `NonLinearCompose` will be deprecated in future version of AbstractOperators.
  • Loading branch information
nantonel authored Mar 21, 2019
1 parent 8e0fbd8 commit e16bbe4
Show file tree
Hide file tree
Showing 12 changed files with 916 additions and 168 deletions.
6 changes: 4 additions & 2 deletions docs/src/calculus.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@ DCAT

```@docs
Compose
NonLinearCompose
Hadamard
HadamardProd
Ax_mul_Bx
Axt_mul_Bx
Ax_mul_Bxt
```

## Transformations
Expand Down
4 changes: 4 additions & 0 deletions src/AbstractOperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,11 @@ include("calculus/Sum.jl")
include("calculus/AffineAdd.jl")
include("calculus/Jacobian.jl")
include("calculus/NonLinearCompose.jl")
include("calculus/Axt_mul_Bx.jl")
include("calculus/Ax_mul_Bxt.jl")
include("calculus/Ax_mul_Bx.jl")
include("calculus/Hadamard.jl")
include("calculus/HadamardProd.jl")

# Non-Linear operators
include("nonlinearoperators/Pow.jl")
Expand Down
110 changes: 110 additions & 0 deletions src/calculus/Ax_mul_Bx.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
#Ax_mul_Bx

export Ax_mul_Bx

"""
`Ax_mul_Bx(A::AbstractOperator,B::AbstractOperator)`
Create an operator `P` such that:
`P*x == (Ax)*(Bx)`
# Example
```julia
julia> A,B = randn(4,4),randn(4,4);
julia> P = Ax_mul_Bx(MatrixOp(A,4),MatrixOp(B,4))
▒*▒ ℝ^4 -> ℝ^(4, 4)
julia> X = randn(4,4);
julia> P*X == (A*X)*(B*X)
true
```
"""
struct Ax_mul_Bx{
L1 <: AbstractOperator,
L2 <: AbstractOperator,
C <: AbstractArray,
D <: AbstractArray,
} <: NonLinearOperator
A::L1
B::L2
bufA::C
bufB::C
bufC::C
bufD::D
function Ax_mul_Bx(A::L1, B::L2, bufA::C, bufB::C, bufC::C, bufD::D) where {L1,L2,C,D}
if ndims(A,1) != 2 || size(A,2) != size(B,2) || size(A,1)[2] != size(B,1)[1]
throw(DimensionMismatch("Cannot compose operators"))
end
new{L1,L2,C,D}(A,B,bufA,bufB,bufC,bufD)
end
end

struct Ax_mul_BxJac{
L1 <: AbstractOperator,
L2 <: AbstractOperator,
C <: AbstractArray,
D <: AbstractArray,
} <: LinearOperator
A::L1
B::L2
bufA::C
bufB::C
bufC::C
bufD::D
end

# Constructors
function Ax_mul_Bx(A::AbstractOperator,B::AbstractOperator)
s,t = size(A,1), codomainType(A)
bufA = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
s,t = size(B,1), codomainType(B)
bufB = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
bufC = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
s,t = size(A,2), domainType(A)
bufD = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
Ax_mul_Bx(A,B,bufA,bufB,bufC,bufD)
end

# Jacobian
function Jacobian(P::Ax_mul_Bx{L1,L2,C,D}, x::AbstractArray) where {L1,L2,C,D}
JA, JB = Jacobian(P.A, x), Jacobian(P.B, x)
Ax_mul_BxJac{typeof(JA),typeof(JB),C,D}(JA,JB,P.bufA,P.bufB,P.bufC,P.bufD)
end

# Mappings
function mul!(y, P::Ax_mul_Bx{L1,L2,C,D}, b) where {L1,L2,C,D}
mul!(P.bufA,P.A,b)
mul!(P.bufB,P.B,b)
mul!(y,P.bufA, P.bufB)
end

function mul!(y, J::AdjointOperator{Ax_mul_BxJac{L1,L2,C,D}}, b) where {L1,L2,C,D}
#y .= J.A.B' * ( J.A.bufA'*b ) + J.A.A' * ( b*J.A.bufB' )
mul!(J.A.bufC, J.A.bufA', b)
mul!(y, J.A.B', J.A.bufC)
mul!(J.A.bufA, b, J.A.bufB')
mul!(J.A.bufD, J.A.A', J.A.bufA)
y .+= J.A.bufD
return y
end

size(P::Union{Ax_mul_Bx,Ax_mul_BxJac}) = ((size(P.A,1)[1],size(P.B,1)[2]),size(P.A,2))

fun_name(L::Union{Ax_mul_Bx,Ax_mul_BxJac}) = fun_name(L.A)*"*"*fun_name(L.B)

domainType(L::Union{Ax_mul_Bx,Ax_mul_BxJac}) = domainType(L.A)
codomainType(L::Union{Ax_mul_Bx,Ax_mul_BxJac}) = codomainType(L.A)

# utils
function permute(P::Ax_mul_Bx{L1,L2,C,D},
p::AbstractVector{Int}) where {L1,L2,C,D <:ArrayPartition}
Ax_mul_Bx(permute(P.A,p),permute(P.B,p),P.bufA,P.bufB,P.bufC,ArrayPartition(P.bufD.x[p]) )
end

remove_displacement(P::Ax_mul_Bx) =
Ax_mul_Bx(remove_displacement(P.A), remove_displacement(P.B), P.bufA, P.bufB, P.bufC, P.bufD)
118 changes: 118 additions & 0 deletions src/calculus/Ax_mul_Bxt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
#Ax_mul_Bxt

export Ax_mul_Bxt

"""
`Ax_mul_Bxt(A::AbstractOperator,B::AbstractOperator)`
Create an operator `P` such that:
`P == (Ax)*(Bx)'`
# Example: Matrix multiplication
```julia
julia> A,B = randn(4,4),randn(4,4);
julia> P = Ax_mul_Bxt(MatrixOp(A),MatrixOp(B))
▒*▒ ℝ^4 -> ℝ^(4, 4)
julia> x = randn(4);
julia> P*x == (A*x)*(B*x)'
true
```
"""
struct Ax_mul_Bxt{
L1 <: AbstractOperator,
L2 <: AbstractOperator,
C <: AbstractArray,
D <: AbstractArray,
} <: NonLinearOperator
A::L1
B::L2
bufA::C
bufB::C
bufC::C
bufD::D
function Ax_mul_Bxt(A::L1, B::L2, bufA::C, bufB::C, bufC::C, bufD::D) where {L1,L2,C,D}
if ndims(A,1) == 1
if size(A) != size(B)
throw(DimensionMismatch("Cannot compose operators"))
end
elseif ndims(A,1) == 2 && ndims(B,1) == 2 && size(A,2) == size(B,2)
if size(A,1)[2] != size(B,1)[2]
throw(DimensionMismatch("Cannot compose operators"))
end
else
throw(DimensionMismatch("Cannot compose operators"))
end
new{L1,L2,C,D}(A,B,bufA,bufB,bufC,bufD)
end
end

struct Ax_mul_BxtJac{
L1 <: AbstractOperator,
L2 <: AbstractOperator,
C <: AbstractArray,
D <: AbstractArray,
} <: LinearOperator
A::L1
B::L2
bufA::C
bufB::C
bufC::C
bufD::D
end

# Constructors
function Ax_mul_Bxt(A::AbstractOperator,B::AbstractOperator)
s,t = size(A,1), codomainType(A)
bufA = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
bufC = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
s,t = size(B,1), codomainType(B)
bufB = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
s,t = size(A,2), domainType(A)
bufD = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
Ax_mul_Bxt(A,B,bufA,bufB,bufC,bufD)
end

# Jacobian
function Jacobian(P::Ax_mul_Bxt{L1,L2,C,D}, x::AbstractArray) where {L1,L2,C,D}
JA, JB = Jacobian(P.A, x), Jacobian(P.B, x)
Ax_mul_BxtJac{typeof(JA),typeof(JB),C,D}(JA,JB,P.bufA,P.bufB,P.bufC,P.bufD)
end

# Mappings
function mul!(y, P::Ax_mul_Bxt{L1,L2,C,D}, b) where {L1,L2,C,D}
mul!(P.bufA,P.A,b)
mul!(P.bufB,P.B,b)
mul!(y,P.bufA, P.bufB')
end

function mul!(y, J::AdjointOperator{Ax_mul_BxtJac{L1,L2,C,D}}, b) where {L1,L2,C,D}
#y .= J.A.A'*(b*(J.A.bufB)) + J.A.B'*(b'*(J.A.bufA))
mul!(J.A.bufC, b, J.A.bufB)
mul!(y, J.A.A', J.A.bufC)
mul!(J.A.bufB, b', J.A.bufA)
mul!(J.A.bufD, J.A.B', J.A.bufB)
y .+= J.A.bufD
return y
end

size(P::Union{Ax_mul_Bxt,Ax_mul_BxtJac}) = ((size(P.A,1)[1],size(P.B,1)[1]),size(P.A,2))

fun_name(L::Union{Ax_mul_Bxt,Ax_mul_BxtJac}) = fun_name(L.A)*"*"*fun_name(L.B)

domainType(L::Union{Ax_mul_Bxt,Ax_mul_BxtJac}) = domainType(L.A)
codomainType(L::Union{Ax_mul_Bxt,Ax_mul_BxtJac}) = codomainType(L.A)

# utils
function permute(P::Ax_mul_Bxt{L1,L2,C,D},
p::AbstractVector{Int}) where {L1,L2,C,D <:ArrayPartition}
Ax_mul_Bxt(permute(P.A,p),permute(P.B,p),P.bufA,P.bufB,P.bufC,ArrayPartition(P.bufD.x[p]) )
end

remove_displacement(P::Ax_mul_Bxt) =
Ax_mul_Bxt(remove_displacement(P.A), remove_displacement(P.B), P.bufA, P.bufB, P.bufC, P.bufD)
137 changes: 137 additions & 0 deletions src/calculus/Axt_mul_Bx.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
#Axt_mul_Bx

export Axt_mul_Bx

"""
`Axt_mul_Bx(A::AbstractOperator,B::AbstractOperator)`
Create an operator `P` such that:
`P*x == (Ax)'*(Bx)`
# Example
```julia
julia> A,B = randn(4,4),randn(4,4);
julia> P = Axt_mul_Bx(MatrixOp(A),MatrixOp(B))
▒*▒ ℝ^4 -> ℝ^1
julia> x = randn(4);
julia> P*x == [(A*x)'*(B*x)]
true
```
"""
struct Axt_mul_Bx{N,
L1 <: AbstractOperator,
L2 <: AbstractOperator,
C <: AbstractArray,
D <: AbstractArray,
} <: NonLinearOperator
A::L1
B::L2
bufA::C
bufB::C
bufC::C
bufD::D
function Axt_mul_Bx(A::L1, B::L2, bufA::C, bufB::C, bufC::C, bufD::D) where {L1,L2,C,D}
if ndims(A,1) == 1
if size(A) != size(B)
throw(DimensionMismatch("Cannot compose operators"))
end
elseif ndims(A,1) == 2 && ndims(B,1) == 2 && size(A,2) == size(B,2)
if size(A,1)[1] != size(B,1)[1]
throw(DimensionMismatch("Cannot compose operators"))
end
else
throw(DimensionMismatch("Cannot compose operators"))
end
N = ndims(A,1)
new{N,L1,L2,C,D}(A,B,bufA,bufB,bufC,bufD)
end
end

struct Axt_mul_BxJac{N,
L1 <: AbstractOperator,
L2 <: AbstractOperator,
C <: AbstractArray,
D <: AbstractArray,
} <: LinearOperator
A::L1
B::L2
bufA::C
bufB::C
bufC::C
bufD::D
end

# Constructors
function Axt_mul_Bx(A::AbstractOperator,B::AbstractOperator)
s,t = size(A,1), codomainType(A)
bufA = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
bufC = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
s,t = size(B,1), codomainType(B)
bufB = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
s,t = size(A,2), domainType(A)
bufD = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
Axt_mul_Bx(A,B,bufA,bufB,bufC,bufD)
end

# Jacobian
function Jacobian(P::Axt_mul_Bx{N,L1,L2,C,D}, x::AbstractArray) where {N,L1,L2,C,D}
JA, JB = Jacobian(P.A, x), Jacobian(P.B, x)
Axt_mul_BxJac{N,typeof(JA),typeof(JB),C,D}(JA,JB,P.bufA,P.bufB,P.bufC,P.bufD)
end

# Mappings
# N == 1 input is a vector
function mul!(y, P::Axt_mul_Bx{1,L1,L2,C,D}, b) where {L1,L2,C,D}
mul!(P.bufA,P.A,b)
mul!(P.bufB,P.B,b)
y[1] = dot(P.bufA,P.bufB)
end

function mul!(y, J::AdjointOperator{Axt_mul_BxJac{1,L1,L2,C,D}}, b) where {L1,L2,C,D}
#y .= conj(J.A.A'*J.A.bufB+J.A.B'*J.A.bufA).*b[1]
mul!(y, J.A.A', J.A.bufB)
mul!(J.A.bufD, J.A.B', J.A.bufA)
y .= conj.( y .+ J.A.bufD ) .* b[1]
return y
end

# N == 2 input is a matrix
function mul!(y, P::Axt_mul_Bx{2,L1,L2,C,D}, b) where {L1,L2,C,D}
mul!(P.bufA,P.A,b)
mul!(P.bufB,P.B,b)
mul!(y,P.bufA',P.bufB)
return y
end

function mul!(y, J::AdjointOperator{Axt_mul_BxJac{2,L1,L2,C,D}}, b) where {L1,L2,C,D}
# y .= J.A.A'*((J.A.bufB)*b') + J.A.B'*((J.A.bufA)*b)
mul!(J.A.bufC, J.A.bufB, b')
mul!(y, J.A.A', J.A.bufC)
mul!(J.A.bufB, J.A.bufA, b)
mul!(J.A.bufD, J.A.B', J.A.bufB)
y .+= J.A.bufD
return y
end

size(P::Union{Axt_mul_Bx{1},Axt_mul_BxJac{1}}) = ((1,),size(P.A,2))
size(P::Union{Axt_mul_Bx{2},Axt_mul_BxJac{2}}) = ((size(P.A,1)[2],size(P.B,1)[2]),size(P.A,2))

fun_name(L::Union{Axt_mul_Bx,Axt_mul_BxJac}) = fun_name(L.A)*"*"*fun_name(L.B)

domainType(L::Union{Axt_mul_Bx,Axt_mul_BxJac}) = domainType(L.A)
codomainType(L::Union{Axt_mul_Bx,Axt_mul_BxJac}) = codomainType(L.A)

# utils
function permute(P::Axt_mul_Bx{N,L1,L2,C,D},
p::AbstractVector{Int}) where {N,L1,L2,C,D <:ArrayPartition}
Axt_mul_Bx(permute(P.A,p),permute(P.B,p),P.bufA,P.bufB,P.bufC,ArrayPartition(P.bufD.x[p]) )
end

remove_displacement(P::Axt_mul_Bx) =
Axt_mul_Bx(remove_displacement(P.A), remove_displacement(P.B), P.bufA, P.bufB, P.bufC, P.bufD)
Loading

0 comments on commit e16bbe4

Please sign in to comment.