diff --git a/CHANGELOG.md b/CHANGELOG.md index eaa0f58a2..14542a3e8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,10 +24,10 @@ New Features Bug Fixes -- Fixes bug that occurs when taking the gradient of ``root`` and ``root_scalar`` with newer versions of JAX (>=0.4.34) and unpins the JAX version +- Fixes bug that occurs when taking the gradient of ``root`` and ``root_scalar`` with newer versions of JAX (>=0.4.34) and unpins the JAX version. - Changes ``FixLambdaGauge`` constraint to now enforce zero flux surface average for lambda, instead of enforcing lambda(rho,0,0)=0 as it was incorrectly doing before. - Fixes bug in ``softmin/softmax`` implementation. - +- Fixes bug that occured when using ``ProximalProjection`` with a scalar optimization algorithm. v0.12.3 ------- diff --git a/desc/optimize/_constraint_wrappers.py b/desc/optimize/_constraint_wrappers.py index dc2e3033c..7695a671b 100644 --- a/desc/optimize/_constraint_wrappers.py +++ b/desc/optimize/_constraint_wrappers.py @@ -836,6 +836,25 @@ def compute_scaled_error(self, x, constants=None): xopt, _ = self._update_equilibrium(x, store=False) return self._objective.compute_scaled_error(xopt, constants[0]) + def compute_scalar(self, x, constants=None): + """Compute the sum of squares error. + + Parameters + ---------- + x : ndarray + State vector. + constants : list + Constant parameters passed to sub-objectives. + + Returns + ------- + f : float + Objective function scalar value. + + """ + f = jnp.sum(self.compute_scaled_error(x, constants=constants) ** 2) / 2 + return f + def compute_unscaled(self, x, constants=None): """Compute the raw value of the objective function. diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 4f0cce0e0..9a4aac89b 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -324,6 +324,44 @@ def test_no_iterations(): np.testing.assert_allclose(x0, out2["x"]) +@pytest.mark.regression +@pytest.mark.optimize +def test_proximal_scalar(): + """Test that proximal scalar optimization works.""" + # test fix for GH issue #1403 + + # optimize to reduce DSHAPE volume from 100 m^3 to 90 m^3 + eq = desc.examples.get("DSHAPE") + optimizer = Optimizer("proximal-fmintr") # proximal scalar optimizer + R_modes = np.vstack( + ( + [0, 0, 0], + eq.surface.R_basis.modes[ + np.max(np.abs(eq.surface.R_basis.modes), 1) > 1, : + ], + ) + ) + Z_modes = eq.surface.Z_basis.modes[ + np.max(np.abs(eq.surface.Z_basis.modes), 1) > 1, : + ] + objective = ObjectiveFunction(Volume(eq=eq, target=90)) # scalar objective function + constraints = ( + FixBoundaryR(eq=eq, modes=R_modes), + FixBoundaryZ(eq=eq, modes=Z_modes), + FixIota(eq=eq), + FixPressure(eq=eq), + FixPsi(eq=eq), + ForceBalance(eq=eq), # force balance constraint for proximal projection + ) + [eq], _ = optimizer.optimize( + things=eq, + objective=objective, + constraints=constraints, + verbose=3, + ) + np.testing.assert_allclose(eq.compute("V")["V"], 90) + + @pytest.mark.regression @pytest.mark.slow @pytest.mark.optimize