Skip to content

Commit

Permalink
Fixed "model is still borrowed" bug
Browse files Browse the repository at this point in the history
  • Loading branch information
Dzuchun committed Nov 25, 2024
1 parent 658b707 commit eb8d179
Show file tree
Hide file tree
Showing 9 changed files with 27 additions and 91 deletions.
65 changes: 9 additions & 56 deletions nacfahi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ where
report: MinimizationReport<Model::Scalar>,
x: X,
y: Y,
) -> FitStat<Model>;
) -> FitStat<Model::Scalar, Model::ParamCount, Model::OwnedModel>;
}

impl<Model, X, Y> FitErrBound<Model, X, Y> for FitterUnit
Expand Down Expand Up @@ -354,7 +354,7 @@ where
report: MinimizationReport<Model::Scalar>,
x: X,
y: Y,
) -> FitStat<Model>
) -> FitStat<Model::Scalar, Model::ParamCount, Model::OwnedModel>
where
Model: FitModelErrors,
{
Expand Down Expand Up @@ -402,7 +402,7 @@ where
let param_errors = (0usize..u_params)
.map(|i| Float::sqrt(covariance_matrix[(i, i)]))
.collect();
let errors = model.with_errors(param_errors);
let errors = Model::with_errors(param_errors);

FitStat {
report,
Expand Down Expand Up @@ -483,64 +483,17 @@ where

/// Result of [`function@fit_stat`].
#[derive(Debug)]
pub struct FitStat<Model: FitModelErrors>
where
Model::Scalar: RealField,
{
pub struct FitStat<Scalar: RealField, ParamCount: Conv, OwnedModel> {
/// Report resulted from the fit
pub report: MinimizationReport<Model::Scalar>,
pub report: MinimizationReport<Scalar>,
/// $\chi^{2}/\test{dof}$ criteria. Should be about 1 for correct fit.
pub reduced_chi2: Model::Scalar,
pub reduced_chi2: Scalar,
/// Type defined by model, containing parameter errors.
///
/// This will usually be the model type itself, but there may be exceptions.
pub errors: Model::OwnedModel,
pub errors: OwnedModel,
/// A parameter covariance matrix. If you don't know what this is, you can safely ignore it.
pub covariance_matrix: GenericMatrix<Model::Scalar, Model::ParamCount, Model::ParamCount>,
}

impl<Model> FitStat<Model>
where
Model: FitModelErrors,
Model::Scalar: RealField,
{
/// Converts [`FitStat`] types, if there's no difference between them internally.
///
/// For example, this works for `FitStat<&mut Model> -> FitStat<Mode>` convertion:
///
/// ```rust
/// # use nacfahi::{models::basic::Constant, fit_stat, FitStat};
/// let x = [1.0, 3.0, -4.0];
/// let y = [-2.0, 5.2, -5.3];
///
/// let mut model = Constant { c: 0.0 };
///
/// let fit_stat: FitStat<&mut Constant<f64>> = fit_stat!(&mut model, x, y);
/// let fit_stat: FitStat<Constant<f64>> = fit_stat.into();
/// ```
#[inline]
pub fn into<OtherModel>(self) -> FitStat<OtherModel>
where
OtherModel: FitModelErrors<
Scalar = Model::Scalar,
ParamCount = Model::ParamCount,
OwnedModel = Model::OwnedModel,
>,
{
let FitStat {
report,
reduced_chi2,
errors,
covariance_matrix,
} = self;

FitStat {
report,
reduced_chi2,
errors,
covariance_matrix,
}
}
pub covariance_matrix: GenericMatrix<Scalar, ParamCount, ParamCount>,
}

#[macro_export]
Expand Down Expand Up @@ -605,7 +558,7 @@ pub fn fit_stat<Model, X, Y>(
y: Y,
minimizer: impl Borrow<LevenbergMarquardt<Model::Scalar>>,
weights: impl Fn(Model::Scalar, Model::Scalar) -> Model::Scalar,
) -> FitStat<Model>
) -> FitStat<Model::Scalar, Model::ParamCount, Model::OwnedModel>
where
Model: FitModelErrors,
Model::Scalar: RealField,
Expand Down
7 changes: 2 additions & 5 deletions nacfahi/src/models/basic/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,14 @@ impl<Scalar: Clone + Zero + One> FitModelXDeriv for Constant<Scalar> {
}
}

impl<Scalar> FitModelErrors for Constant<Scalar>
impl<Scalar: 'static> FitModelErrors for Constant<Scalar>
where
Self: FitModel<Scalar = Scalar, ParamCount = U1>,
{
type OwnedModel = Self;

#[inline]
fn with_errors(
&self,
errors: GenericArray<Self::Scalar, Self::ParamCount>,
) -> Self::OwnedModel {
fn with_errors(errors: GenericArray<Self::Scalar, Self::ParamCount>) -> Self::OwnedModel {
let [c] = errors.into_array();
Self { c }
}
Expand Down
7 changes: 2 additions & 5 deletions nacfahi/src/models/basic/exponent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,14 @@ impl<Scalar: Clone + Mul<Output = Scalar> + Pow<Scalar, Output = Scalar> + Float
}
}

impl<Scalar> FitModelErrors for Exponent<Scalar>
impl<Scalar: 'static> FitModelErrors for Exponent<Scalar>
where
Scalar: Clone + Mul<Output = Scalar> + Pow<Scalar, Output = Scalar> + FloatConst,
{
type OwnedModel = Self;

#[inline]
fn with_errors(
&self,
errors: GenericArray<Self::Scalar, Self::ParamCount>,
) -> Self::OwnedModel {
fn with_errors(errors: GenericArray<Self::Scalar, Self::ParamCount>) -> Self::OwnedModel {
let [a, b] = errors.into_array();
Exponent { a, b }
}
Expand Down
14 changes: 4 additions & 10 deletions nacfahi/src/models/basic/gaussian.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,16 +152,13 @@ impl<Scalar: Float + FloatConst> FitModelXDeriv for Gaussian<Scalar> {
}
}

impl<Scalar> FitModelErrors for Gaussian<Scalar>
impl<Scalar: 'static> FitModelErrors for Gaussian<Scalar>
where
Self: FitModel<Scalar = Scalar, ParamCount = U3>,
{
type OwnedModel = Gaussian<Scalar>;

fn with_errors(
&self,
errors: GenericArray<Self::Scalar, Self::ParamCount>,
) -> Self::OwnedModel {
fn with_errors(errors: GenericArray<Self::Scalar, Self::ParamCount>) -> Self::OwnedModel {
let [x_c, s, a] = errors.into_array();
Gaussian { a, s, x_c }
}
Expand Down Expand Up @@ -238,17 +235,14 @@ pub struct GaussianSErr<Scalar> {
pub x_c_err: Scalar,
}

impl<Scalar> FitModelErrors for GaussianS<Scalar>
impl<Scalar: 'static> FitModelErrors for GaussianS<Scalar>
where
Self: FitModel<Scalar = Scalar, ParamCount = U2>,
{
type OwnedModel = GaussianSErr<Scalar>;

#[inline]
fn with_errors(
&self,
errors: GenericArray<Self::Scalar, Self::ParamCount>,
) -> Self::OwnedModel {
fn with_errors(errors: GenericArray<Self::Scalar, Self::ParamCount>) -> Self::OwnedModel {
let [a_err, x_c_err] = errors.into_array();
GaussianSErr { a_err, x_c_err }
}
Expand Down
7 changes: 2 additions & 5 deletions nacfahi/src/models/basic/linear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,14 @@ where
}
}

impl<Scalar> FitModelErrors for Linear<Scalar>
impl<Scalar: 'static> FitModelErrors for Linear<Scalar>
where
Scalar: Clone + Add<Output = Scalar> + Mul<Output = Scalar> + One,
{
type OwnedModel = Linear<Scalar>;

#[inline]
fn with_errors(
&self,
errors: GenericArray<Self::Scalar, Self::ParamCount>,
) -> Self::OwnedModel {
fn with_errors(errors: GenericArray<Self::Scalar, Self::ParamCount>) -> Self::OwnedModel {
let [a, b] = errors.into_array();
Linear { a, b }
}
Expand Down
3 changes: 1 addition & 2 deletions nacfahi/src/models/basic/polynomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,14 @@ where
}
}

impl<const ORDER: usize, Scalar> FitModelErrors for Polynomial<ORDER, Scalar>
impl<const ORDER: usize, Scalar: 'static> FitModelErrors for Polynomial<ORDER, Scalar>
where
Const<ORDER>: IntoArrayLength,
Self: FitModel<Scalar = Scalar, ParamCount = Const<ORDER>>,
{
type OwnedModel = Polynomial<ORDER, Scalar>;

fn with_errors(
&self,
errors: GenericArray<Self::Scalar, <Self::ParamCount as generic_array_storage::Conv>::TNum>,
) -> Self::OwnedModel {
Polynomial {
Expand Down
10 changes: 3 additions & 7 deletions nacfahi/src/models/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,10 @@ pub trait FitModelErrors: FitModel {
/// Type of the error model
///
/// Most of the time, this can be just `Self`.
type OwnedModel;
type OwnedModel: 'static;

/// Creates new model representing errors from the error array
fn with_errors(
&self,
errors: GenericArray<Self::Scalar, <Self::ParamCount as generic_array_storage::Conv>::TNum>,
) -> Self::OwnedModel;
}
Expand Down Expand Up @@ -150,13 +149,12 @@ impl<Model: FitModelErrors> FitModelErrors for &'_ mut Model {

#[inline]
fn with_errors(
&self,
errors: GenericArray<
Model::Scalar,
<Self::ParamCount as generic_array_storage::Conv>::TNum,
>,
) -> Self::OwnedModel {
<Model as FitModelErrors>::with_errors(self, errors)
<Model as FitModelErrors>::with_errors(errors)
}
}

Expand Down Expand Up @@ -276,12 +274,10 @@ where

#[inline]
fn with_errors(
&self,
errors: GenericArray<Self::Scalar, <Self::ParamCount as Conv>::TNum>,
) -> Self::OwnedModel {
let self_array = GenericArray::from_array(self.each_ref());
let unflat = unflatten::<_, _, <Model::ParamCount as Conv>::TNum>(errors);
self_array.zip(unflat, Model::with_errors).into_array()
unflat.map(Model::with_errors).into_array()
}
}

Expand Down
1 change: 0 additions & 1 deletion nacfahi/src/models/utility/fixed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ where
type OwnedModel = ();

fn with_errors(
&self,
_errors: GenericArray<Self::Scalar, <Self::ParamCount as Conv>::TNum>,
) -> Self::OwnedModel {
}
Expand Down
4 changes: 4 additions & 0 deletions nacfahi/tests/generic_fit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,8 @@ fn fit_err_generic<const ORDER: usize, Len: ArrayLength>(
};

let stat = fit_stat!(&mut model, x, y);

model.params[0] = 0.0; // test for "borrowed model"

core::hint::black_box(stat);
}

0 comments on commit eb8d179

Please sign in to comment.