Return a single ragged list from f_step in JAX mode
Add JAX mode tests for `test_backends_agree`
jessegrabowski committed Aug 2, 2024
1 parent 3240c4e commit 56caa49
Showing 5 changed files with 25 additions and 6 deletions.
9 changes: 9 additions & 0 deletions cge_modeling/__init__.py
@@ -4,6 +4,15 @@
 from cge_modeling.tools.output_tools import display_info_as_table, latex_print_equations
 from cge_modeling.pytensorf.rewrites import prod_to_no_zero_prod  # noqa: F401
 
+import logging
+
+_log = logging.getLogger("cge_modeling")
+if not logging.root.handlers:
+    _log.setLevel(logging.INFO)
+    if len(_log.handlers) == 0:
+        handler = logging.StreamHandler()
+        _log.addHandler(handler)
+
 
 __version__ = "0.0.1"
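Aside (not part of the commit): the block above follows a common library-logging pattern, attaching a default StreamHandler only when the application has not already configured the root logger. The package logs at INFO out of the box but defers to user configuration. A minimal user-side sketch, assuming only that the logger name matches the package:

```python
import logging

# Hypothetical user code: the package logger is an ordinary `logging`
# logger named "cge_modeling", so an application can quiet the library's
# default INFO output by raising the level.
logging.getLogger("cge_modeling").setLevel(logging.WARNING)
```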
2 changes: 1 addition & 1 deletion cge_modeling/base/cge.py
@@ -789,7 +789,7 @@ def f_jac(*args, **kwargs):
     f_resid, f_grad, f_hess, f_hessp = None, None, None, None
 
     if mode == "JAX":
-        f_resid, f_grad, f_hessp = jax_loss_grad_hessp(
+        f_resid, f_grad, f_hess, f_hessp = jax_loss_grad_hessp(
             system, variables, parameters
         )
     else:
13 changes: 10 additions & 3 deletions cge_modeling/pytensorf/compile.py
@@ -406,20 +406,27 @@ def f_loss_wrapped(x, theta):
 
     f_loss = jax.jit(f_loss_wrapped)
 
-    grad = jax.grad(f_loss, 0)
+    grad = jax.grad(f_loss_wrapped, 0)
 
     def f_grad_jax(x, theta):
         return jnp.stack(grad(x, theta))
 
     f_grad = jax.jit(f_grad_jax)
 
+    _f_hess_jax = jax.jacfwd(f_grad_jax, argnums=0)
+
+    def f_hess_jax(x, theta):
+        return jnp.stack(_f_hess_jax(x, theta))
+
+    f_hess = jax.jit(f_hess_jax)
+
     def f_hessp_jax(x, p, theta):
         _, u = jax.jvp(lambda x: f_grad_jax(x, theta), (x,), (p,))
         return jnp.stack(u)
 
     f_hessp = jax.jit(f_hessp_jax)
 
-    return f_loss, f_grad, f_hessp
+    return f_loss, f_grad, f_hess, f_hessp
 
 
 def jax_euler_step(system, variables, parameters):
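Aside (not part of the commit): the new `f_hess` builds the dense Hessian by forward-mode differentiating the gradient (`jax.jacfwd` of the gradient function), while `f_hessp` computes Hessian-vector products matrix-free via `jax.jvp` of the gradient. A self-contained sketch of the same pattern on a toy loss (all names here are illustrative, not from the repo):

```python
import jax
import jax.numpy as jnp

# Toy loss standing in for the compiled CGE system (illustration only).
def loss(x, theta):
    return jnp.sum(theta * x**2)

grad = jax.grad(loss, argnums=0)             # gradient with respect to x
hess = jax.jit(jax.jacfwd(grad, argnums=0))  # dense Hessian: Jacobian of the gradient

def hessp(x, p, theta):
    # Hessian-vector product without materializing the full Hessian.
    _, u = jax.jvp(lambda x_: grad(x_, theta), (x,), (p,))
    return u

x = jnp.array([1.0, 2.0])
theta = jnp.array([0.5, 3.0])
print(hess(x, theta))                # diag(2 * theta)
print(hessp(x, jnp.ones(2), theta))  # equals hess(x, theta) @ p for p = ones
```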
@@ -479,7 +486,7 @@ def step(**kwargs):
         x_next = flat_tensor_to_ragged_list(x_next_vec, x_shapes)
         theta_next = flat_tensor_to_ragged_list(theta_next_vec, theta_shapes)
 
-        return x_next, theta_next
+        return [*x_next, *theta_next]
 
     f_step = jax.jit(step)
     return f_step
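Aside (not part of the commit): this hunk is the change named in the commit title. `step` now returns one flat ragged list (variable arrays followed by parameter arrays) instead of an `(x_next, theta_next)` pair, so downstream code can unpack every quantity from a single list. A sketch of the splitting idea, using a hypothetical stand-in for `flat_tensor_to_ragged_list`:

```python
import math
import jax.numpy as jnp

# Hypothetical stand-in for flat_tensor_to_ragged_list: split a flat
# vector into arrays with the given shapes (illustration only).
def flat_to_ragged(vec, shapes):
    out, i = [], 0
    for shape in shapes:
        size = math.prod(shape)  # math.prod(()) == 1, so scalar shapes work
        out.append(vec[i : i + size].reshape(shape))
        i += size
    return out

x_next = flat_to_ragged(jnp.arange(6.0), [(2,), (2, 2)])
theta_next = flat_to_ragged(jnp.array([7.0]), [()])
result = [*x_next, *theta_next]  # one flat ragged list, as f_step now returns
```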
1 change: 0 additions & 1 deletion cge_modeling/pytensorf/rewrites.py
@@ -13,7 +13,6 @@ def prod_to_no_zero_prod(fgraph, node):
     Note that this only affects product reduction Ops, it's not the same as a multiplication.
     """
     if isinstance(node.op, Prod) and not node.op.no_zeros_in_input:
-        print("hi :)")
         (x,) = node.inputs
         new_op = Prod(
             dtype=node.op.dtype, acc_dtype=node.op.dtype, no_zeros_in_input=True
6 changes: 5 additions & 1 deletion tests/test_cge_model.py
@@ -437,16 +437,20 @@ def test_pytensor_from_sympy(
     ],
     ids=["minimize", "root", "euler"],
 )
+@pytest.mark.parametrize("mode", ["FAST_RUN", "JAX"], ids=["FAST_RUN", "JAX"])
 def test_backends_agree(
     model_function: Callable,
     calibrate_model: Callable,
     data: dict,
     method: str,
     solver_kwargs: dict,
+    mode: str,
 ):
+    if mode in ["JAX"]:
+        pytest.importorskip(mode.lower())
     model_numba = model_function(backend="numba")
     model_pytensor = model_function(
-        backend="pytensor", mode="FAST_COMPILE", parse_equations_to_sympy=False
+        backend="pytensor", mode=mode, parse_equations_to_sympy=False
     )
 
     def solver_agreement_checks(results: list, names: list):
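Aside (not part of the commit): `pytest.importorskip` lets the JAX-parametrized cases skip, rather than fail, on environments without jax installed. The pattern in isolation:

```python
import pytest

# Illustration only: importorskip returns the module when it can be
# imported and skips the calling test otherwise.
def test_requires_jax():
    jax = pytest.importorskip("jax")
    assert jax.numpy.square(3.0) == 9.0
```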
