Skip to content

Commit

Permalink
Made &mut mandatory in fit/fit_stat
Browse files Browse the repository at this point in the history
  • Loading branch information
Dzuchun committed Nov 25, 2024
1 parent eb8d179 commit d79c999
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 22 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,10 @@ let linear_y = y.map(f64::ln);
// exponential model
let mut expo_model = Exponent { a: 1.0, b: 0.0 };
// expolinear (exponential mapped to linear)
let expolinear = model_map(&mut expo_model, LnMap);
let mut expolinear = model_map(&mut expo_model, LnMap);

// fit!
let report = fit!(expolinear, x, linear_y);
let report = fit!(&mut expolinear, x, linear_y);

# assert!(
# report.termination.was_successful(),
Expand Down
29 changes: 16 additions & 13 deletions nacfahi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ where
#[doc(hidden)]
fn fit(
minimizer: impl Borrow<LevenbergMarquardt<Model::Scalar>>,
model: impl BorrowMut<Model>,
model: &mut Model,
x: impl Borrow<X>,
y: impl Borrow<Y>,
weights: impl Fn(Model::Scalar, Model::Scalar) -> Model::Scalar,
Expand Down Expand Up @@ -290,7 +290,7 @@ where
#[inline(always)]
fn fit(
minimizer: impl Borrow<LevenbergMarquardt<Model::Scalar>>,
mut model: impl BorrowMut<Model>,
model: &mut Model,
x: impl Borrow<X>,
y: impl Borrow<Y>,
weights: impl Fn(Model::Scalar, Model::Scalar) -> Model::Scalar,
Expand Down Expand Up @@ -321,7 +321,7 @@ where
report: MinimizationReport<Model::Scalar>,
x: X,
y: Y,
) -> FitStat<Model::Scalar, Model::ParamCount, Model::OwnedModel>;
) -> FitStat<Model>;
}

impl<Model, X, Y> FitErrBound<Model, X, Y> for FitterUnit
Expand Down Expand Up @@ -354,7 +354,7 @@ where
report: MinimizationReport<Model::Scalar>,
x: X,
y: Y,
) -> FitStat<Model::Scalar, Model::ParamCount, Model::OwnedModel>
) -> FitStat<Model>
where
Model: FitModelErrors,
{
Expand Down Expand Up @@ -467,7 +467,7 @@ macro_rules! fit {
/// **TIP**: The [`FitBound`] is an unfortunate outcome to strict type system. In case you deal with generic code, just put the `fit!` statement down, and add the bound you seemingly violate - you **should** be good after that.
#[must_use = "Minimization report is really important to check if approximation happened at all"]
pub fn fit<Model, X, Y>(
model: Model,
model: &mut Model,
x: X,
y: Y,
minimizer: impl Borrow<LevenbergMarquardt<Model::Scalar>>,
Expand All @@ -483,17 +483,20 @@ where

/// Result of [`function@fit_stat`].
#[derive(Debug)]
pub struct FitStat<Scalar: RealField, ParamCount: Conv, OwnedModel> {
pub struct FitStat<Model: FitModelErrors>
where
Model::Scalar: RealField,
{
/// Report resulted from the fit
pub report: MinimizationReport<Scalar>,
pub report: MinimizationReport<Model::Scalar>,
/// $\chi^{2}/\test{dof}$ criteria. Should be about 1 for correct fit.
pub reduced_chi2: Scalar,
pub reduced_chi2: Model::Scalar,
/// Type defined by model, containing parameter errors.
///
/// This will usually be the model type itself, but there may be exceptions.
pub errors: OwnedModel,
pub errors: Model::OwnedModel,
/// A parameter covariance matrix. If you don't know what this is, you can safely ignore it.
pub covariance_matrix: GenericMatrix<Scalar, ParamCount, ParamCount>,
pub covariance_matrix: GenericMatrix<Model::Scalar, Model::ParamCount, Model::ParamCount>,
}

#[macro_export]
Expand Down Expand Up @@ -553,17 +556,17 @@ macro_rules! fit_stat {
/// **TIP**: The `FitDimensionsBound` is an unfortunate outcome to strict type system. In case you deal with generic code, just put the `fit!` statement down, and add the bound you seemingly violate - you **should** be good after that.
#[must_use = "Covariance matrix are the only point to call this function specifically"]
pub fn fit_stat<Model, X, Y>(
mut model: Model,
model: &mut Model,
x: X,
y: Y,
minimizer: impl Borrow<LevenbergMarquardt<Model::Scalar>>,
weights: impl Fn(Model::Scalar, Model::Scalar) -> Model::Scalar,
) -> FitStat<Model::Scalar, Model::ParamCount, Model::OwnedModel>
) -> FitStat<Model>
where
Model: FitModelErrors,
Model::Scalar: RealField,
FitterUnit: FitErrBound<Model, X, Y>,
{
let report = FitterUnit::fit(minimizer, &mut model, &x, &y, weights);
let report = FitterUnit::fit(minimizer, model.borrow_mut(), &x, &y, weights);
FitterUnit::produce_stat(model, report, x, y)
}
8 changes: 4 additions & 4 deletions nacfahi/tests/fit_macro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,19 +98,19 @@ fn custom_lavmar_weight() {
fn weight_function_respected() {
let x = [0.0];
let y = [1.0];
let model = Constant { c: 0.0 };
let mut model = Constant { c: 0.0 };
let weights = |_: f64, _: f64| panic!("Weight function panic");

let _ = fit!(model, x, y, weights = weights);
let _ = fit!(&mut model, x, y, weights = weights);
}

#[test]
fn default_weights_unused() {
let x = [0.0];
let y = [1.0];
let model = Constant { c: 0.0 };
let mut model = Constant { c: 0.0 };
#[allow(unused)]
let default_weights = |_: f64, _: f64| panic!("Weight function panic");

let _ = fit!(model, x, y, weights = default_weights);
let _ = fit!(&mut model, x, y, weights = default_weights);
}
2 changes: 1 addition & 1 deletion nacfahi/tests/generic_fit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ fn fit_err_generic<const ORDER: usize, Len: ArrayLength>(
y: GenericArray<f64, Len>,
) where
Polynomial<ORDER, f64>: FitModelErrors<Scalar = f64>,
for<'r> FitterUnit: FitErrBound<&'r mut Polynomial<ORDER, f64>, GenericArray<f64, Len>>,
FitterUnit: FitErrBound<Polynomial<ORDER, f64>, GenericArray<f64, Len>>,
{
let mut model = Polynomial {
params: [0.0f64; ORDER],
Expand Down
4 changes: 2 additions & 2 deletions nacfahi/tests/utility_models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ fn fit_log() {
// exponential model
let mut expo_model = Exponent { a: 1.0, b: 0.0 };
// expolinear (exponential mapped to linear)
let expolinear = model_map(&mut expo_model, LnMap);
let mut expolinear = model_map(&mut expo_model, LnMap);

// fit!
let report = fit!(expolinear, x, linear_y);
let report = fit!(&mut expolinear, x, linear_y);

assert!(
report.termination.was_successful(),
Expand Down

0 comments on commit d79c999

Please sign in to comment.