Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Measure density matrix inplace #532

Merged
merged 3 commits into from
Jan 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/YaoAPI/src/registers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -717,7 +717,7 @@ end
Density matrix type, where `state` is a matrix.
Type parameter `D` is the number of levels, it can also be specified by a keyword argument `nlevel`.
"""
struct DensityMatrix{D,T,MT<:AbstractMatrix{T}} <: AbstractRegister{D}
mutable struct DensityMatrix{D,T,MT<:AbstractMatrix{T}} <: AbstractRegister{D}
state::MT
end

Expand Down
9 changes: 9 additions & 0 deletions lib/YaoArrayRegister/src/density_matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,12 @@ function Base.join(r0::DensityMatrix{D}, rs::DensityMatrix{D}...) where {D}
st = kron(state(r0), state.(rs)...)
return DensityMatrix{D}(st)
end

function YaoAPI.collapseto!(rho::DensityMatrix, locsval::Pair)
locs = locsval.first isa AllLocs ? (1:nqudits(rho)) : locsval.first
ic = itercontrol(nqudits(rho), collect(locs), locsval.second) .+ 1
st = normalize!(rho.state[ic, ic])
fill!(rho.state, 0)
rho.state[ic, ic] .= st
return rho
end
54 changes: 46 additions & 8 deletions lib/YaoArrayRegister/src/measure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,6 @@ YaoAPI.measure(
rng::AbstractRNG = Random.GLOBAL_RNG,
) = _measure(rng, basis(reg), reg |> probs, nshots)

YaoAPI.measure(
::ComputationalBasis,
reg::DensityMatrix,
::AllLocs;
nshots::Int = 1,
rng::AbstractRNG = Random.GLOBAL_RNG,
) = _measure(rng, basis(reg), reg |> probs, nshots)

function YaoAPI.measure(
::ComputationalBasis,
reg::BatchedArrayReg,
Expand Down Expand Up @@ -141,6 +133,52 @@ function YaoAPI.measure!(
return res
end

## DensityMatrix
YaoAPI.measure(
::ComputationalBasis,
rho::DensityMatrix,
::AllLocs;
nshots::Int = 1,
rng::AbstractRNG = Random.GLOBAL_RNG,
) = _measure(rng, basis(rho), rho |> probs, nshots)


function YaoAPI.measure(op, rho::DensityMatrix, locs; kwargs...)
rrho = density_matrix(rho, locs)
res = measure(op, rrho, AllLocs(); kwargs...)
return res
end

function YaoAPI.measure!(
postprocess::PostProcess,
op::ComputationalBasis,
rho::DensityMatrix,
locs;
rng::AbstractRNG = Random.GLOBAL_RNG,
)
if !(locs isa AllLocs)
rrho = density_matrix(rho, locs)
bs = basis(rrho)
ps = rrho |> probs
else
bs = basis(rho)
ps = rho |> probs
end
res = _measure(rng, bs, ps, 1)[]
if postprocess isa RemoveMeasured
ic = itercontrol(nqudits(rho), collect(locs isa AllLocs ? (1:nqudits(rho)) : locs), res) .+ 1
rho.state = rho.state[ic, ic]
normalize!(rho)
elseif postprocess isa NoPostProcess
collapseto!(rho, locs => res)
elseif postprocess isa ResetTo
collapseto!(rho, locs => postprocess.x)
else
error("`$postprocess` is not yet supported for DensityMatrix")
end
return res
end

import YaoAPI: select, select!
select(r::AbstractArrayReg, bits::AbstractVector{T}) where {T<:Integer} =
arrayreg(r.state[Int64.(bits).+1, :]; nbatch=nbatch(r), nlevel=nlevel(r))
Expand Down
9 changes: 9 additions & 0 deletions lib/YaoArrayRegister/src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ isnormalized(r::AbstractArrayReg) =
all(sum(copy(r) |> relax!(to_nactive = nqudits(r)) |> probs, dims = 1) .≈ 1)
isnormalized(r::AdjointRegister) = isnormalized(parent(r))

function isnormalized(r::DensityMatrix)
return tr(r.state) ≈ 1
end

"""
normalize!(r::AbstractArrayReg)

Expand Down Expand Up @@ -40,6 +44,11 @@ end

LinearAlgebra.normalize!(r::AdjointRegister) = (normalize!(parent(r)); r)

function LinearAlgebra.normalize!(r::DensityMatrix)
r.state = r.state / tr(r.state)
return r
end

LinearAlgebra.norm(r::ArrayReg) = norm(statevec(r))
LinearAlgebra.norm(r::BatchedArrayReg) =
[norm(view(reshape(r.state, :, nbatch(r)), :, ib)) for ib = 1:nbatch(r)]
Expand Down
64 changes: 64 additions & 0 deletions lib/YaoArrayRegister/test/density_matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,67 @@ end
r = join(r2, r1)
@test measure(r) == [bit"101110"]
end

@testset "collapseto!" begin
r = density_matrix(ghz_state(3))
collapseto!(r, (1, 2) => (0, 0))
res = measure(r; nshots=1000)
@test all(==(bit"000"), res)
@test isnormalized(r)
end

@testset "measure on subset of qubits" begin
r1 = density_matrix(arrayreg(bit"110"))
r2 = density_matrix(arrayreg(bit"101"))
r = join(r2, r1)
@test measure(r, (1, 2)) == [bit"10"]
end

@testset "measure on density matrix, collapse" begin
# AllLocs
reg = ghz_state(3)
rho = density_matrix(reg)
res = measure!(rho)
res2 = measure(rho; nshots=10)
@test all(==(res), res2)

# specific locs
reg = ghz_state(3)
rho = density_matrix(reg, (1, 2))
res = measure!(rho, (1, 2))
res2 = measure(rho, (1, 2); nshots=10)
@test all(==(res), res2)
end


@testset "measure on density matrix, reset" begin
# AllLocs
reg = uniform_state(3)
rho = density_matrix(reg)
res = measure!(ResetTo(bit"110"), rho)
res2 = measure(rho; nshots=10)
@test all(==(bit"110"), res2)

# specific locs
reg = uniform_state(5)
rho = density_matrix(reg, (1, 2, 3))
res = measure!(ResetTo(bit"10"), rho, (1, 2))
res2 = measure(rho, (1, 2); nshots=10)
@test all(==(bit"10"), res2)
end

@testset "measure on density matrix, remove" begin
# AllLocs
reg = uniform_state(3)
rho = density_matrix(reg, (1, 2))
res = measure!(RemoveMeasured(), rho)
@test nqubits(rho) == 0
@test isnormalized(rho)

# specific locs
reg = uniform_state(5)
rho = density_matrix(reg, (1, 2, 3))
res = measure!(RemoveMeasured(), rho, (1, 2))
@test nqubits(rho) == 1
@test isnormalized(rho)
end
2 changes: 1 addition & 1 deletion lib/YaoBlocks/src/autodiff/autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ module AD
using BitBasis, YaoArrayRegister, YaoAPI
using ..YaoBlocks
import ChainRulesCore:
rrule, @non_differentiable, NoTangent, Tangent, backing, AbstractTangent, ZeroTangent
rrule, @non_differentiable, NoTangent, Tangent, backing, AbstractTangent, ZeroTangent, AbstractThunk, unthunk
import YaoAPI: mat_back!, apply_back!
using SparseArrays, LuxurySparse, LinearAlgebra

Expand Down
5 changes: 3 additions & 2 deletions lib/YaoBlocks/src/autodiff/chainrules_patch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ unsafe_primitive_tangent(x::Number) = x
for GT in [:RotationGate, :ShiftGate, :PhaseGate, :(Scale{<:Number})]
@eval function recursive_create_tangent(c::$GT)
lst = map(fieldnames(typeof(c))) do fn
fn => unsafe_primitive_tangent(getfield(c, fn))
fn => unsafe_primitive_tangent(unthunk(getfield(c, fn)))
end
nt = NamedTuple(lst)
Tangent{typeof(c),typeof(nt)}(nt)
Expand Down Expand Up @@ -46,7 +46,7 @@ for GT in [
]
@eval function recursive_create_tangent(c::$GT)
lst = map(fieldnames(typeof(c))) do fn
fn => unsafe_composite_tangent(getfield(c, fn))
fn => unsafe_composite_tangent(unthunk(getfield(c, fn)))
end
nt = NamedTuple(lst)
Tangent{typeof(c),typeof(nt)}(nt)
Expand Down Expand Up @@ -209,6 +209,7 @@ rrule(::typeof(parent), reg::AdjointArrayReg) = parent(reg), adjy -> (NoTangent(
rrule(::typeof(Base.adjoint), reg::AbstractArrayReg) =
Base.adjoint(reg), adjy -> (NoTangent(), parent(adjy))

_totype(::Type{T}, x::AbstractThunk) where {T} = _totype(T, unthunk(x))
_totype(::Type{T}, x::AbstractArray{T}) where {T} = x
_totype(::Type{T}, x::AbstractArray{T2}) where {T,T2} = convert.(T, x)
_match_type(::ArrayReg{D}, mat) where D = ArrayReg{D}(mat)
Expand Down
32 changes: 32 additions & 0 deletions lib/YaoBlocks/src/measure_ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,35 @@ function YaoAPI.measure!(
return reg isa ArrayReg ? bb.values[res[]] : bb.values[res]
end

function YaoAPI.measure!(
::NoPostProcess,
bb::BlockedBasis,
rho::DensityMatrix{D, T},
::AllLocs;
rng::AbstractRNG = Random.GLOBAL_RNG,
) where {D,T}
state = @inbounds rho.state[bb.perm, bb.perm] # permute to make eigen values sorted
pl = diag(state)
# cummulate probabilities in each block
pl_block = zeros(eltype(pl), nblocks(bb))
for i = 1:nblocks(bb)
for k in subblock(bb, i)
pl_block[i] += pl[k]
end
end
res = sample(rng, 1:nblocks(bb), Weights(real.(pl_block)))
# collapse to the selected block
range = subblock(bb, res)
mblock = state[range, range] ./ pl_block[res]
state .= zero(T)
state[range, range] .= mblock

# undo permute and assign back
rho.state[bb.perm, bb.perm] .= state
return bb.values[res]
end


function YaoAPI.measure!(
p::ResetTo,
op::AbstractBlock,
Expand All @@ -230,6 +259,9 @@ function measure(op::AbstractBlock, reg::AbstractRegister, locs::AllLocs; kwargs
res = measure(ComputationalBasis(), copy(reg) |> V', locs; kwargs...)
diag(mat(E))[Int64.(res).+1]
end
function measure(op::AbstractBlock, rho::DensityMatrix, locs::AllLocs; kwargs...)
Base.invoke(measure, Tuple{AbstractBlock,AbstractRegister,AllLocs}, op, rho, locs; kwargs...)
end

render_mlocs(alllocs::AllLocs, locs) = locs
render_mlocs(alllocs, locs) = alllocs[locs]
Expand Down
15 changes: 6 additions & 9 deletions lib/YaoBlocks/test/autodiff/chainrules_patch.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import Zygote, ForwardDiff
using Random, Test
using YaoBlocks, YaoArrayRegister
using ChainRulesCore: Tangent
using ChainRulesCore: Tangent, unthunk, AbstractThunk

@testset "recursive_create_tangent" begin
c = chain(put(5, 2 => chain(Rx(1.4), Rx(0.5))), cnot(5, 3, 1), put(5, 3 => Rx(-0.5)))
Expand All @@ -10,16 +10,13 @@ using ChainRulesCore: Tangent
end

@testset "construtors" begin
@test Zygote.gradient(x -> x.list[1].blocks[1].theta, sum([chain(1, Rz(0.3))]))[1] == (n=nothing,
list = NamedTuple{
(:n, :blocks,),
Tuple{Nothing, Vector{NamedTuple{(:block, :theta),Tuple{Nothing,Float64}}}},
}[(n=nothing, blocks = [(block = nothing, theta = 1.0)],)],
)
@test Zygote.gradient(
res = Zygote.gradient(x -> x.list[1].blocks[1].theta, sum([chain(1, Rz(0.3))]))[1]
@test res.list[].blocks[1].theta ≈ 1.0
res = Zygote.gradient(
x -> getfield(getfield(x, :content), :theta),
Daggered(Rx(0.5)),
)[1] == (content = (block = nothing, theta = 1.0),)
)[1]
@test res.content.theta ≈ 1.0
end

@testset "rules" begin
Expand Down
10 changes: 10 additions & 0 deletions lib/YaoBlocks/test/measure_ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,13 @@ end
0.5 * sum([put(3, i=>Z) for i=1:3]),
chain([put(3, i=>H) for i=1:3]))
end

@testset "measure on density matrix" begin
reg = uniform_state(4)
rho = density_matrix(reg, 1:3)
res = measure!(kron(Z, Z, Z), rho)
p = probs(rho)
@test count(!iszero, p) == 4
res2 = measure(kron(Z, Z, Z), rho; nshots=10)
@test all(==(res), res2)
end
Loading