Skip to content

Commit

Permalink
Jax euler inverse (#24)
Browse files Browse the repository at this point in the history
* use `solve(A, Bv)` instead of `inv(A) @ B` in `jax_euler_step`

* Remove dead code

* Fix euler step test code
  • Loading branch information
jessegrabowski authored Aug 2, 2024
1 parent 6962d59 commit b6bd3bf
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 12 deletions.
7 changes: 0 additions & 7 deletions cge_modeling/base/cge.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,13 +891,6 @@ def f_euler(theta_final, n_steps, progressbar=True, update_freq=1, **data):
current_variable_vals = current_step[: len(self.variable_names)]
current_parameter_vals = current_step[len(self.variable_names) :]

# current_variable_vals = flat_current_step[
# : len(self.variable_names)
# ]
# current_parameter_vals = flat_current_step[
# len(self.variable_names) :
# ]

current_variables = {
k: current_variable_vals[i] for i, k in enumerate(self.variable_names)
}
Expand Down
4 changes: 1 addition & 3 deletions cge_modeling/pytensorf/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,10 +472,8 @@ def step(**kwargs):
step_size = (theta_final_vec - theta0_vec) / n_steps

A = jax.jacobian(f_sys, 0)(x_vec, theta_vec)
A_inv = jax.scipy.linalg.solve(A, jnp.eye(A.shape[0]))

_, Bv = jax.jvp(lambda theta: f_sys(x_vec, theta), (theta_vec,), (step_size,))
step = -A_inv @ Bv
step = -jax.scipy.linalg.solve(A, Bv)

x_next_vec = x_vec + step
theta_next_vec = theta_vec + step_size
Expand Down
3 changes: 1 addition & 2 deletions tests/test_pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@ def test_euler_approximation_1d():
n_steps = pt.iscalar("n_steps")

A = make_jacobian(eq, [y])
A_inv = 1 / A
B = make_jacobian(eq, [x])

x0_final, result = euler_approximation(A_inv, B, variables=[y], parameters=[x], n_steps=n_steps)
x0_final, result = euler_approximation(A, B, variables=[y], parameters=[x], n_steps=n_steps)
f = pytensor.function([x, y, x0_final, n_steps], result)
y_values, x_values = f(0, 0, np.array([10.0]), 10_000)
true = -np.cos(np.array([10])) + 1
Expand Down

0 comments on commit b6bd3bf

Please sign in to comment.