diff --git a/src/interfaces.jl b/src/interfaces.jl index 90c6cba..db45fad 100644 --- a/src/interfaces.jl +++ b/src/interfaces.jl @@ -489,16 +489,33 @@ size_all_negative(::SpinGlass) = false size_all_positive(::SpinGlass) = false # NOTE: `findmin` and `findmax` are required by `ProblemReductions.jl` +""" + GTNSolver(; optimizer=TreeSA(), single=false, usecuda=false, T=Float64) + +A generic tensor network based backend for the `findbest`, `findmin` and `findmax` interfaces in `ProblemReductions.jl`. + +Keyword arguments +------------------------------------- +* `optimizer` is the optimizer for the tensor network contraction. +* `single` is a switch to return single solution instead of all solutions. +* `usecuda` is a switch to use CUDA (when applicable), user need to call statement `using CUDA` before turning on this switch. +* `T` is the "base" element type, sometimes can be used to reduce the memory cost. +""" Base.@kwdef struct GTNSolver optimizer::OMEinsum.CodeOptimizer = TreeSA() + single::Bool = false usecuda::Bool = false T::Type = Float64 end -function Base.findmin(problem::AbstractProblem, solver::GTNSolver) - res = collect(solve(GenericTensorNetwork(problem; optimizer=solver.optimizer), ConfigsMin(; tree_storage=true); usecuda=solver.usecuda, T=solver.T)[].c) - return map(x -> ProblemReductions.id_to_config(problem, Int.(x) .+ 1), res) -end -function Base.findmax(problem::AbstractProblem, solver::GTNSolver) - res = collect(solve(GenericTensorNetwork(problem; optimizer=solver.optimizer), ConfigsMax(; tree_storage=true); usecuda=solver.usecuda, T=solver.T)[].c) - return map(x -> ProblemReductions.id_to_config(problem, Int.(x) .+ 1), res) -end +for (PROP, SPROP, SOLVER) in [ + (:ConfigsMin, :SingleConfigMin, :findmin), (:ConfigsMax, :SingleConfigMax, :findmax) + ] + @eval function Base.$(SOLVER)(problem::AbstractProblem, solver::GTNSolver) + if solver.single + res = [solve(GenericTensorNetwork(problem; optimizer=solver.optimizer), $(SPROP)(); usecuda=solver.usecuda, T=solver.T)[].c.data] + else + res = collect(solve(GenericTensorNetwork(problem; optimizer=solver.optimizer), $(PROP)(; tree_storage=true); usecuda=solver.usecuda, T=solver.T)[].c) + end + return map(x -> ProblemReductions.id_to_config(problem, Int.(x) .+ 1), res) + end +end \ No newline at end of file diff --git a/test/interfaces.jl b/test/interfaces.jl index 8261bee..6f49683 100644 --- a/test/interfaces.jl +++ b/test/interfaces.jl @@ -254,4 +254,8 @@ end solver2 = BruteForce() @test Set(findmin(sg, solver1)) == Set(findmin(sg, solver2)) @test Set(findmax(sg, solver1)) == Set(findmax(sg, solver2)) + + solver3 = GTNSolver(; optimizer=TreeSA(ntrials=1), single=true) + @test findmin(sg, solver3)[] ∈ findmin(sg, solver2) + @test findmax(sg, solver3)[] ∈ findmax(sg, solver2) end