Skip to content

Commit

Permalink
allow single config in solver (#92)
Browse files Browse the repository at this point in the history
  • Loading branch information
GiggleLiu authored Dec 23, 2024
1 parent 4759afb commit 8432a0a
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 8 deletions.
33 changes: 25 additions & 8 deletions src/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -489,16 +489,33 @@ size_all_negative(::SpinGlass) = false
size_all_positive(::SpinGlass) = false

# NOTE: `findmin` and `findmax` are required by `ProblemReductions.jl`
"""
GTNSolver(; optimizer=TreeSA(), single=false, usecuda=false, T=Float64)
A generic tensor network based backend for the `findbest`, `findmin` and `findmax` interfaces in `ProblemReductions.jl`.
Keyword arguments
-------------------------------------
* `optimizer` is the optimizer for the tensor network contraction.
* `single` is a switch to return single solution instead of all solutions.
* `usecuda` is a switch to use CUDA (when applicable), user need to call statement `using CUDA` before turning on this switch.
* `T` is the "base" element type, sometimes can be used to reduce the memory cost.
"""
Base.@kwdef struct GTNSolver
optimizer::OMEinsum.CodeOptimizer = TreeSA()
single::Bool = false
usecuda::Bool = false
T::Type = Float64
end
function Base.findmin(problem::AbstractProblem, solver::GTNSolver)
res = collect(solve(GenericTensorNetwork(problem; optimizer=solver.optimizer), ConfigsMin(; tree_storage=true); usecuda=solver.usecuda, T=solver.T)[].c)
return map(x -> ProblemReductions.id_to_config(problem, Int.(x) .+ 1), res)
end
function Base.findmax(problem::AbstractProblem, solver::GTNSolver)
res = collect(solve(GenericTensorNetwork(problem; optimizer=solver.optimizer), ConfigsMax(; tree_storage=true); usecuda=solver.usecuda, T=solver.T)[].c)
return map(x -> ProblemReductions.id_to_config(problem, Int.(x) .+ 1), res)
end
for (PROP, SPROP, SOLVER) in [
(:ConfigsMin, :SingleConfigMin, :findmin), (:ConfigsMax, :SingleConfigMax, :findmax)
]
@eval function Base.$(SOLVER)(problem::AbstractProblem, solver::GTNSolver)
if solver.single
res = [solve(GenericTensorNetwork(problem; optimizer=solver.optimizer), $(SPROP)(); usecuda=solver.usecuda, T=solver.T)[].c.data]
else
res = collect(solve(GenericTensorNetwork(problem; optimizer=solver.optimizer), $(PROP)(; tree_storage=true); usecuda=solver.usecuda, T=solver.T)[].c)
end
return map(x -> ProblemReductions.id_to_config(problem, Int.(x) .+ 1), res)
end
end
4 changes: 4 additions & 0 deletions test/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -254,4 +254,8 @@ end
solver2 = BruteForce()
@test Set(findmin(sg, solver1)) == Set(findmin(sg, solver2))
@test Set(findmax(sg, solver1)) == Set(findmax(sg, solver2))

solver3 = GTNSolver(; optimizer=TreeSA(ntrials=1), single=true)
@test findmin(sg, solver3)[] findmin(sg, solver2)
@test findmax(sg, solver3)[] findmax(sg, solver2)
end

0 comments on commit 8432a0a

Please sign in to comment.