Skip to content

Commit

Permalink
forward_hessian_nograd_sparse
Browse files Browse the repository at this point in the history
  • Loading branch information
stefan-k committed Mar 8, 2024
1 parent c1db03b commit ffd3802
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 0 deletions.
144 changes: 144 additions & 0 deletions python/finitediff-py/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,145 @@ fn central_jacobian_vec_prod<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'p
}
}

/// Forward Hessian
#[pyfunction]
fn forward_hessian<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunction> {
if f.as_ref(py).is_callable() {
PyCFunction::new_closure(
py,
None,
None,
move |args: &PyTuple, _kwargs: Option<&PyDict>| -> PyResult<Py<PyArray2<f64>>> {
Python::with_gil(|py| {
let out = (ndarr::forward_hessian(
&|x: &Array1<f64>| -> Result<Array1<f64>, Error> {
let x = PyArray1::from_array(py, x);
Ok(f.call(py, (x,), None)?
.extract::<&PyArray1<f64>>(py)?
.to_owned_array())
},
))(&process_args(args, 0)?)?;
Ok(out.into_pyarray(py).into())
})
},
)
} else {
not_callable!(py, f)
}
}

/// Central Hessian
#[pyfunction]
fn central_hessian<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunction> {
if f.as_ref(py).is_callable() {
PyCFunction::new_closure(
py,
None,
None,
move |args: &PyTuple, _kwargs: Option<&PyDict>| -> PyResult<Py<PyArray2<f64>>> {
Python::with_gil(|py| {
let out = (ndarr::central_hessian(
&|x: &Array1<f64>| -> Result<Array1<f64>, Error> {
let x = PyArray1::from_array(py, x);
Ok(f.call(py, (x,), None)?
.extract::<&PyArray1<f64>>(py)?
.to_owned_array())
},
))(&process_args(args, 0)?)?;
Ok(out.into_pyarray(py).into())
})
},
)
} else {
not_callable!(py, f)
}
}

/// Forward Hessian times vec
#[pyfunction]
fn forward_hessian_vec_prod<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunction> {
if f.as_ref(py).is_callable() {
PyCFunction::new_closure(
py,
None,
None,
move |args: &PyTuple, _kwargs: Option<&PyDict>| -> PyResult<Py<PyArray1<f64>>> {
Python::with_gil(|py| {
let out = (ndarr::forward_hessian_vec_prod(&|x: &Array1<f64>| -> Result<
Array1<f64>,
Error,
> {
let x = PyArray1::from_array(py, x);
Ok(f.call(py, (x,), None)?
.extract::<&PyArray1<f64>>(py)?
.to_owned_array())
}))(
&process_args(args, 0)?, &process_args(args, 1)?
)?;
Ok(out.into_pyarray(py).into())
})
},
)
} else {
not_callable!(py, f)
}
}

/// Central Hessian times vec
#[pyfunction]
fn central_hessian_vec_prod<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunction> {
if f.as_ref(py).is_callable() {
PyCFunction::new_closure(
py,
None,
None,
move |args: &PyTuple, _kwargs: Option<&PyDict>| -> PyResult<Py<PyArray1<f64>>> {
Python::with_gil(|py| {
let out = (ndarr::central_hessian_vec_prod(&|x: &Array1<f64>| -> Result<
Array1<f64>,
Error,
> {
let x = PyArray1::from_array(py, x);
Ok(f.call(py, (x,), None)?
.extract::<&PyArray1<f64>>(py)?
.to_owned_array())
}))(
&process_args(args, 0)?, &process_args(args, 1)?
)?;
Ok(out.into_pyarray(py).into())
})
},
)
} else {
not_callable!(py, f)
}
}

/// Forward Hessian nograd
#[pyfunction]
fn forward_hessian_nograd<'py>(py: Python<'py>, f: Py<PyAny>) -> PyResult<&'py PyCFunction> {
if f.as_ref(py).is_callable() {
PyCFunction::new_closure(
py,
None,
None,
move |args: &PyTuple, _kwargs: Option<&PyDict>| -> PyResult<Py<PyArray2<f64>>> {
Python::with_gil(|py| {
let out = (ndarr::forward_hessian_nograd(
&|x: &Array1<f64>| -> Result<f64, Error> {
let x = PyArray1::from_array(py, x);
Ok(f.call(py, (x,), None)?.extract::<f64>(py)?)
},
))(&process_args(args, 0)?)?;
Ok(out.into_pyarray(py).into())
})
},
)
} else {
not_callable!(py, f)
}
}

/// A Python module implemented in Rust.
#[pymodule]
fn finitediff(_py: Python, m: &PyModule) -> PyResult<()> {
Expand All @@ -193,5 +332,10 @@ fn finitediff(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(central_jacobian, m)?)?;
m.add_function(wrap_pyfunction!(forward_jacobian_vec_prod, m)?)?;
m.add_function(wrap_pyfunction!(central_jacobian_vec_prod, m)?)?;
m.add_function(wrap_pyfunction!(forward_hessian, m)?)?;
m.add_function(wrap_pyfunction!(central_hessian, m)?)?;
m.add_function(wrap_pyfunction!(forward_hessian_vec_prod, m)?)?;
m.add_function(wrap_pyfunction!(central_hessian_vec_prod, m)?)?;
m.add_function(wrap_pyfunction!(forward_hessian_nograd, m)?)?;
Ok(())
}
35 changes: 35 additions & 0 deletions python/finitediff-py/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
central_jacobian,
forward_jacobian_vec_prod,
central_jacobian_vec_prod,
forward_hessian,
central_hessian,
forward_hessian_vec_prod,
central_hessian_vec_prod,
forward_hessian_nograd,
)
import numpy as np

Expand Down Expand Up @@ -73,6 +78,36 @@ def op(x):
p = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
print(j(x, p))


def f(x):
return x[0] + x[1] ** 2 + x[2] * x[3] ** 2


def g(x):
return np.array([1.0, 2.0 * x[1], x[3] ** 2, 2.0 * x[3] * x[2]])


h = forward_hessian(g)
x = np.array([1.0, 1.0, 1.0, 1.0])
print(h(x))

h = central_hessian(g)
x = np.array([1.0, 1.0, 1.0, 1.0])
print(h(x))

h = forward_hessian_vec_prod(g)
x = np.array([1.0, 1.0, 1.0, 1.0])
p = np.array([2.0, 3.0, 4.0, 5.0])
print(h(x, p))

h = central_hessian_vec_prod(g)
x = np.array([1.0, 1.0, 1.0, 1.0])
p = np.array([2.0, 3.0, 4.0, 5.0])
print(h(x, p))

h = forward_hessian_nograd(f)
x = np.array([1.0, 1.0, 1.0, 1.0])
print(h(x))
# class NotCallable:
# pass

Expand Down

0 comments on commit ffd3802

Please sign in to comment.