Skip to content

Commit

Permalink
Fix covariances update (#322)
Browse files Browse the repository at this point in the history
  • Loading branch information
relf authored Nov 8, 2023
1 parent 39b3eee commit 083fc9a
Showing 1 changed file with 32 additions and 1 deletion.
33 changes: 32 additions & 1 deletion algorithms/linfa-clustering/src/gaussian_mixture/algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,9 @@ impl<F: Float> GaussianMixtureModel<F> {
)?;
self.means = means;
self.weights = weights / F::cast(n_samples);
self.covariances = covariances;
// GmmCovarType = Full()
self.precisions_chol = Self::compute_precisions_cholesky_full(&covariances)?;
self.precisions_chol = Self::compute_precisions_cholesky_full(&self.covariances)?;
Ok(())
}

Expand Down Expand Up @@ -488,7 +489,9 @@ mod tests {
use ndarray::{array, concatenate, ArrayView1, ArrayView2, Axis};
use ndarray_rand::rand::prelude::ThreadRng;
use ndarray_rand::rand::SeedableRng;
use ndarray_rand::rand_distr::Normal;
use ndarray_rand::rand_distr::{Distribution, StandardNormal};
use ndarray_rand::RandomExt;

#[test]
fn autotraits() {
Expand Down Expand Up @@ -570,6 +573,34 @@ mod tests {
);
}

#[test]
fn test_gmm_covariances() {
let rng = rand_xoshiro::Xoshiro256Plus::seed_from_u64(123);

let data_0 = ndarray::Array::random((500,), Normal::new(0., 0.5).unwrap());
let data_1 = ndarray::Array::random((500,), Normal::new(1., 0.5).unwrap());
let data_2 = ndarray::Array::random((500,), Normal::new(2., 0.5).unwrap());
let data = ndarray::concatenate![ndarray::Axis(0), data_0, data_1, data_2];

let data_2d = data.insert_axis(ndarray::Axis(1)).to_owned();
let dataset = linfa::DatasetBase::from(data_2d);

let gmm = GaussianMixtureModel::params(3)
.n_runs(1)
.tolerance(1e-4)
.with_rng(rng)
.max_n_iterations(500)
.fit(&dataset)
.expect("GMM fit");

// expected results from scikit-learn 1.3.1
let expected = array![[[0.22564062]], [[0.26204446]], [[0.23393885]]];
let expected = Array::from_iter(expected.iter().cloned());
let actual = gmm.covariances();
let actual = Array::from_iter(actual.iter().cloned());
assert_abs_diff_eq!(expected, actual, epsilon = 1e-1);
}

fn function_test_1d(x: &Array2<f64>) -> Array2<f64> {
let mut y = Array2::zeros(x.dim());
Zip::from(&mut y).and(x).for_each(|yi, &xi| {
Expand Down

0 comments on commit 083fc9a

Please sign in to comment.